Fix race errors. (#4184)

* Fix data races.

* Fix block num.

* Pub key lock.

* Fixed race errors.

* Fix flag.

* Fix comments.

* Fix comments.

* Fix type.

* Fix race errors in tests.

Co-authored-by: Konstantin <k.potapov@softpro.com>
pull/4219/head
Konstantin 2 years ago committed by GitHub
parent 4eabc120b1
commit 06de7dcd6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      Makefile
  2. 14
      consensus/checks.go
  3. 17
      consensus/consensus.go
  4. 30
      consensus/consensus_fbft.go
  5. 15
      consensus/consensus_fbft_test.go
  6. 15
      consensus/consensus_service.go
  7. 2
      consensus/consensus_test.go
  8. 28
      consensus/consensus_v2.go
  9. 2
      consensus/construct.go
  10. 9
      consensus/construct_test.go
  11. 21
      consensus/debug.go
  12. 2
      consensus/double_sign.go
  13. 4
      consensus/leader.go
  14. 6
      consensus/threshold.go
  15. 30
      consensus/validator.go
  16. 14
      consensus/view_change.go
  17. 6
      consensus/view_change_msg.go
  18. 6
      consensus/view_change_test.go
  19. 70
      core/tx_pool_test.go
  20. 5
      core/types/block.go
  21. 30
      internal/utils/keylocker/keylocker.go
  22. 17
      internal/utils/keylocker/keylocker_test.go
  23. 35
      internal/utils/singleton_test.go
  24. 14
      internal/utils/timer.go
  25. 2
      node/api.go
  26. 2
      node/node_newblock.go
  27. 16
      p2p/host.go
  28. 13
      p2p/stream/protocols/sync/chain.go
  29. 19
      p2p/stream/protocols/sync/protocol_test.go
  30. 41
      p2p/utils.go
  31. 37
      p2p/utils_test.go
  32. 2
      scripts/travis_go_checker.sh

@ -156,4 +156,4 @@ go-vet:
go vet ./...
go-test:
go test ./...
go test -vet=all -race ./...

@ -57,9 +57,9 @@ func (consensus *Consensus) senderKeySanityChecks(msg *msg_pb.Message, senderKey
func (consensus *Consensus) isRightBlockNumAndViewID(recvMsg *FBFTMessage,
) bool {
if recvMsg.ViewID != consensus.GetCurBlockViewID() || recvMsg.BlockNum != consensus.blockNum {
if recvMsg.ViewID != consensus.GetCurBlockViewID() || recvMsg.BlockNum != consensus.BlockNum() {
consensus.getLogger().Debug().
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Str("recvMsg", recvMsg.String()).
Msg("BlockNum/viewID not match")
return false
@ -103,12 +103,12 @@ func (consensus *Consensus) onAnnounceSanityChecks(recvMsg *FBFTMessage) bool {
}
func (consensus *Consensus) isRightBlockNumCheck(recvMsg *FBFTMessage) bool {
if recvMsg.BlockNum < consensus.blockNum {
if recvMsg.BlockNum < consensus.BlockNum() {
consensus.getLogger().Debug().
Uint64("MsgBlockNum", recvMsg.BlockNum).
Msg("Wrong BlockNum Received, ignoring!")
return false
} else if recvMsg.BlockNum-consensus.blockNum > MaxBlockNumDiff {
} else if recvMsg.BlockNum-consensus.BlockNum() > MaxBlockNumDiff {
consensus.getLogger().Debug().
Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("MaxBlockNumDiff", MaxBlockNumDiff).
@ -122,7 +122,7 @@ func (consensus *Consensus) newBlockSanityChecks(
blockObj *types.Block, recvMsg *FBFTMessage,
) bool {
if blockObj.NumberU64() != recvMsg.BlockNum ||
recvMsg.BlockNum < consensus.blockNum {
recvMsg.BlockNum < consensus.BlockNum() {
consensus.getLogger().Warn().
Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("blockNum", blockObj.NumberU64()).
@ -152,12 +152,12 @@ func (consensus *Consensus) onViewChangeSanityCheck(recvMsg *FBFTMessage) bool {
Interface("SendPubKeys", recvMsg.SenderPubkeys).
Msg("[onViewChangeSanityCheck]")
if consensus.blockNum > recvMsg.BlockNum {
if consensus.BlockNum() > recvMsg.BlockNum {
consensus.getLogger().Debug().
Msg("[onViewChange] Message BlockNum Is Low")
return false
}
if consensus.blockNum < recvMsg.BlockNum {
if consensus.BlockNum() < recvMsg.BlockNum {
consensus.getLogger().Warn().
Msg("[onViewChangeSanityCheck] MsgBlockNum is different from my BlockNumber")
return false

@ -3,6 +3,7 @@ package consensus
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/harmony-one/harmony/crypto/bls"
@ -45,7 +46,7 @@ type Consensus struct {
// FBFTLog stores the pbft messages and blocks during FBFT process
FBFTLog *FBFTLog
// phase: different phase of FBFT protocol: pre-prepare, prepare, commit, finish etc
phase FBFTPhase
phase *LockedFBFTPhase
// current indicates what state a node is in
current State
// isBackup declarative the node is in backup mode
@ -131,7 +132,7 @@ type Consensus struct {
// finality of previous consensus in the unit of milliseconds
finality int64
// finalityCounter keep tracks of the finality time
finalityCounter int64
finalityCounter atomic.Value //int64
dHelper *downloadHelper
}
@ -168,6 +169,12 @@ func (consensus *Consensus) GetPublicKeys() multibls.PublicKeys {
return consensus.priKey.GetPublicKeys()
}
func (consensus *Consensus) GetLeaderPubKey() *bls_cosi.PublicKeyWrapper {
consensus.pubKeyLock.Lock()
defer consensus.pubKeyLock.Unlock()
return consensus.LeaderPubKey
}
func (consensus *Consensus) GetPrivateKeys() multibls.PrivateKeys {
return consensus.priKey
}
@ -197,6 +204,10 @@ func (consensus *Consensus) IsBackup() bool {
return consensus.isBackup
}
func (consensus *Consensus) BlockNum() uint64 {
return atomic.LoadUint64(&consensus.blockNum)
}
// New create a new Consensus record
func New(
host p2p.Host, shard uint32, leader p2p.Peer, multiBLSPriKey multibls.PrivateKeys,
@ -209,7 +220,7 @@ func New(
consensus.BlockNumLowChan = make(chan struct{}, 1)
// FBFT related
consensus.FBFTLog = NewFBFTLog()
consensus.phase = FBFTAnnounce
consensus.phase = NewLockedFBFTPhase(FBFTAnnounce)
consensus.current = State{mode: Normal}
// FBFT timeout
consensus.consensusTimeout = createTimeout()

@ -0,0 +1,30 @@
package consensus
import "sync"
type LockedFBFTPhase struct {
mu sync.Mutex
phase FBFTPhase
}
func NewLockedFBFTPhase(initialPhrase FBFTPhase) *LockedFBFTPhase {
return &LockedFBFTPhase{
phase: initialPhrase,
}
}
func (a *LockedFBFTPhase) Set(phrase FBFTPhase) {
a.mu.Lock()
a.phase = phrase
a.mu.Unlock()
}
func (a *LockedFBFTPhase) Get() FBFTPhase {
a.mu.Lock()
defer a.mu.Unlock()
return a.phase
}
func (a *LockedFBFTPhase) String() string {
return a.Get().String()
}

@ -0,0 +1,15 @@
package consensus
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestLockedFBFTPhase(t *testing.T) {
s := NewLockedFBFTPhase(FBFTAnnounce)
require.Equal(t, FBFTAnnounce, s.Get())
s.Set(FBFTCommit)
require.Equal(t, FBFTCommit, s.Get())
}

@ -403,7 +403,9 @@ func (consensus *Consensus) UpdateConsensusInformation() Mode {
consensus.getLogger().Info().
Str("leaderPubKey", leaderPubKey.Bytes.Hex()).
Msg("[UpdateConsensusInformation] Most Recent LeaderPubKey Updated Based on BlockChain")
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = leaderPubKey
consensus.pubKeyLock.Unlock()
}
}
@ -442,8 +444,11 @@ func (consensus *Consensus) UpdateConsensusInformation() Mode {
// IsLeader check if the node is a leader or not by comparing the public key of
// the node with the leader public key
func (consensus *Consensus) IsLeader() bool {
consensus.pubKeyLock.Lock()
obj := consensus.LeaderPubKey.Object
consensus.pubKeyLock.Unlock()
for _, key := range consensus.priKey {
if key.Pub.Object.IsEqual(consensus.LeaderPubKey.Object) {
if key.Pub.Object.IsEqual(obj) {
return true
}
}
@ -469,13 +474,13 @@ func (consensus *Consensus) SetViewChangingID(viewID uint64) {
// StartFinalityCount set the finality counter to current time
func (consensus *Consensus) StartFinalityCount() {
consensus.finalityCounter = time.Now().UnixNano()
consensus.finalityCounter.Store(time.Now().UnixNano())
}
// FinishFinalityCount calculate the current finality
func (consensus *Consensus) FinishFinalityCount() {
d := time.Now().UnixNano()
consensus.finality = (d - consensus.finalityCounter) / 1000000
consensus.finality = (d - consensus.finalityCounter.Load().(int64)) / 1000000
consensusFinalityHistogram.Observe(float64(consensus.finality))
}
@ -491,7 +496,7 @@ func (consensus *Consensus) switchPhase(subject string, desired FBFTPhase) {
Str("to:", desired.String()).
Str("switchPhase:", subject)
consensus.phase = desired
consensus.phase.Set(desired)
return
}
@ -580,7 +585,7 @@ func (consensus *Consensus) NumSignaturesIncludedInBlock(block *types.Block) uin
// getLogger returns logger for consensus contexts added
func (consensus *Consensus) getLogger() *zerolog.Logger {
logger := utils.Logger().With().
Uint64("myBlock", consensus.blockNum).
Uint64("myBlock", consensus.BlockNum()).
Uint64("myViewID", consensus.GetCurBlockViewID()).
Str("phase", consensus.phase.String()).
Str("mode", consensus.current.Mode().String()).

@ -38,7 +38,7 @@ func TestConsensusInitialization(t *testing.T) {
// FBFTLog
assert.Equal(t, fbtLog, consensus.FBFTLog)
assert.Equal(t, FBFTAnnounce, consensus.phase)
assert.Equal(t, FBFTAnnounce, consensus.phase.Get())
// State / consensus.current
assert.Equal(t, state.mode, consensus.current.mode)

@ -135,7 +135,7 @@ func (consensus *Consensus) finalCommit() {
consensus.getLogger().Info().
Int64("NumCommits", numCommits).
Msg("[finalCommit] Finalizing Consensus")
beforeCatchupNum := consensus.blockNum
beforeCatchupNum := consensus.BlockNum()
leaderPriKey, err := consensus.GetConsensusLeaderPrivateKey()
if err != nil {
@ -188,7 +188,7 @@ func (consensus *Consensus) finalCommit() {
} else {
consensus.getLogger().Info().
Hex("blockHash", curBlockHash[:]).
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Msg("[finalCommit] Sent Committed Message")
}
consensus.getLogger().Info().Msg("[finalCommit] Start consensus timer")
@ -203,7 +203,7 @@ func (consensus *Consensus) finalCommit() {
p2p.ConstructMessage(msgToSend))
consensus.getLogger().Info().
Hex("blockHash", curBlockHash[:]).
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Hex("lastCommitSig", commitSigAndBitmap).
Msg("[finalCommit] Queued Committed Message")
}
@ -211,7 +211,7 @@ func (consensus *Consensus) finalCommit() {
block.SetCurrentCommitSig(commitSigAndBitmap)
err = consensus.commitBlock(block, FBFTMsg)
if err != nil || consensus.blockNum-beforeCatchupNum != 1 {
if err != nil || consensus.BlockNum()-beforeCatchupNum != 1 {
consensus.getLogger().Err(err).
Uint64("beforeCatchupBlockNum", beforeCatchupNum).
Msg("[finalCommit] Leader failed to commit the confirmed block")
@ -264,7 +264,7 @@ func (consensus *Consensus) finalCommit() {
// BlockCommitSigs returns the byte array of aggregated
// commit signature and bitmap signed on the block
func (consensus *Consensus) BlockCommitSigs(blockNum uint64) ([]byte, error) {
if consensus.blockNum <= 1 {
if consensus.BlockNum() <= 1 {
return nil, nil
}
lastCommits, err := consensus.Blockchain.ReadCommitSig(blockNum)
@ -363,7 +363,7 @@ func (consensus *Consensus) Start(
case <-consensus.syncReadyChan:
consensus.getLogger().Info().Msg("[ConsensusMainLoop] syncReadyChan")
consensus.mutex.Lock()
if consensus.blockNum < consensus.Blockchain.CurrentHeader().Number().Uint64()+1 {
if consensus.BlockNum() < consensus.Blockchain.CurrentHeader().Number().Uint64()+1 {
consensus.SetBlockNum(consensus.Blockchain.CurrentHeader().Number().Uint64() + 1)
consensus.SetViewIDs(consensus.Blockchain.CurrentHeader().ViewID().Uint64() + 1)
mode := consensus.UpdateConsensusInformation()
@ -396,7 +396,7 @@ func (consensus *Consensus) Start(
Uint64("MsgBlockNum", newBlock.NumberU64()).
Msg("[ConsensusMainLoop] Received Proposed New Block!")
if newBlock.NumberU64() < consensus.blockNum {
if newBlock.NumberU64() < consensus.BlockNum() {
consensus.getLogger().Warn().Uint64("newBlockNum", newBlock.NumberU64()).
Msg("[ConsensusMainLoop] received old block, abort")
continue
@ -443,7 +443,7 @@ func (consensus *Consensus) Close() error {
// waitForCommit wait extra 2 seconds for commit phase to finish
func (consensus *Consensus) waitForCommit() {
if consensus.Mode() != Normal || consensus.phase != FBFTCommit {
if consensus.Mode() != Normal || consensus.phase.Get() != FBFTCommit {
return
}
// We only need to wait consensus is in normal commit phase
@ -569,7 +569,7 @@ func (consensus *Consensus) preCommitAndPropose(blk *types.Block) error {
} else {
consensus.getLogger().Info().
Str("blockHash", blk.Hash().Hex()).
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Hex("lastCommitSig", bareMinimumCommit).
Msg("[preCommitAndPropose] Sent Committed Message")
}
@ -621,7 +621,7 @@ func (consensus *Consensus) tryCatchup() error {
if consensus.BlockVerifier == nil {
return errors.New("consensus haven't finished initialization")
}
initBN := consensus.blockNum
initBN := consensus.BlockNum()
defer consensus.postCatchup(initBN)
blks, msgs, err := consensus.getLastMileBlocksAndMsg(initBN)
@ -683,7 +683,9 @@ func (consensus *Consensus) commitBlock(blk *types.Block, committedMsg *FBFTMess
func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg *FBFTMessage) {
atomic.StoreUint64(&consensus.blockNum, blk.NumberU64()+1)
consensus.SetCurBlockViewID(committedMsg.ViewID + 1)
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = committedMsg.SenderPubkeys[0]
consensus.pubKeyLock.Unlock()
// Update consensus keys at last so the change of leader status doesn't mess up normal flow
if blk.IsLastBlockInEpoch() {
consensus.SetMode(consensus.UpdateConsensusInformation())
@ -693,15 +695,15 @@ func (consensus *Consensus) SetupForNewConsensus(blk *types.Block, committedMsg
}
func (consensus *Consensus) postCatchup(initBN uint64) {
if initBN < consensus.blockNum {
if initBN < consensus.BlockNum() {
consensus.getLogger().Info().
Uint64("From", initBN).
Uint64("To", consensus.blockNum).
Uint64("To", consensus.BlockNum()).
Msg("[TryCatchup] Caught up!")
consensus.switchPhase("TryCatchup", FBFTAnnounce)
}
// catch up and skip from view change trap
if initBN < consensus.blockNum && consensus.IsViewChangingMode() {
if initBN < consensus.BlockNum() && consensus.IsViewChangingMode() {
consensus.current.SetMode(Normal)
consensus.consensusTimeout[timeoutViewChange].Stop()
}

@ -30,7 +30,7 @@ func (consensus *Consensus) populateMessageFields(
request *msg_pb.ConsensusRequest, blockHash []byte,
) *msg_pb.ConsensusRequest {
request.ViewId = consensus.GetCurBlockViewID()
request.BlockNum = consensus.blockNum
request.BlockNum = consensus.BlockNum()
request.ShardId = consensus.ShardID
// 32 byte block hash
request.BlockHash = blockHash

@ -2,6 +2,7 @@ package consensus
import (
"bytes"
"sync/atomic"
"testing"
"github.com/ethereum/go-ethereum/common"
@ -86,7 +87,7 @@ func TestConstructPreparedMessage(test *testing.T) {
[]*bls.PublicKeyWrapper{&leaderKeyWrapper},
leaderPriKey.Sign(message),
common.BytesToHash(consensus.blockHash[:]),
consensus.blockNum,
consensus.BlockNum(),
consensus.GetCurBlockViewID(),
)
if _, err := consensus.Decider.AddNewVote(
@ -94,7 +95,7 @@ func TestConstructPreparedMessage(test *testing.T) {
[]*bls.PublicKeyWrapper{&validatorKeyWrapper},
validatorPriKey.Sign(message),
common.BytesToHash(consensus.blockHash[:]),
consensus.blockNum,
consensus.BlockNum(),
consensus.GetCurBlockViewID(),
); err != nil {
test.Log(err)
@ -156,7 +157,7 @@ func TestConstructPrepareMessage(test *testing.T) {
consensus.SetCurBlockViewID(2)
consensus.blockHash = [32]byte{}
copy(consensus.blockHash[:], []byte("random"))
consensus.blockNum = 1000
atomic.StoreUint64(&consensus.blockNum, 1000)
sig := priKeyWrapper1.Pri.SignHash(consensus.blockHash[:])
network, err := consensus.construct(msg_pb.MessageType_PREPARE, nil, []*bls.PrivateKeyWrapper{&priKeyWrapper1})
@ -248,7 +249,7 @@ func TestConstructCommitMessage(test *testing.T) {
consensus.SetCurBlockViewID(2)
consensus.blockHash = [32]byte{}
copy(consensus.blockHash[:], []byte("random"))
consensus.blockNum = 1000
atomic.StoreUint64(&consensus.blockNum, 1000)
sigPayload := []byte("payload")

@ -1,26 +1,21 @@
package consensus
// GetConsensusPhase returns the current phase of the consensus
func (c *Consensus) GetConsensusPhase() string {
return c.phase.String()
func (consensus *Consensus) GetConsensusPhase() string {
return consensus.phase.String()
}
// GetConsensusMode returns the current mode of the consensus
func (c *Consensus) GetConsensusMode() string {
return c.current.mode.String()
func (consensus *Consensus) GetConsensusMode() string {
return consensus.current.mode.String()
}
// GetCurBlockViewID returns the current view ID of the consensus
func (c *Consensus) GetCurBlockViewID() uint64 {
return c.current.GetCurBlockViewID()
func (consensus *Consensus) GetCurBlockViewID() uint64 {
return consensus.current.GetCurBlockViewID()
}
// GetViewChangingID returns the current view changing ID of the consensus
func (c *Consensus) GetViewChangingID() uint64 {
return c.current.GetViewChangingID()
}
// GetBlockNum return the current blockNum of the consensus struct
func (c *Consensus) GetBlockNum() uint64 {
return c.blockNum
func (consensus *Consensus) GetViewChangingID() uint64 {
return consensus.current.GetViewChangingID()
}

@ -137,7 +137,7 @@ func (consensus *Consensus) checkDoubleSign(recvMsg *FBFTMessage) bool {
func (consensus *Consensus) couldThisBeADoubleSigner(
recvMsg *FBFTMessage,
) bool {
num, hash := consensus.blockNum, recvMsg.BlockHash
num, hash := consensus.BlockNum(), recvMsg.BlockHash
suspicious := !consensus.FBFTLog.HasMatchingAnnounce(num, hash) ||
!consensus.FBFTLog.HasMatchingPrepared(num, hash)
if suspicious {

@ -78,7 +78,7 @@ func (consensus *Consensus) announce(block *types.Block) {
}
// Construct broadcast p2p message
if err := consensus.msgSender.SendWithRetry(
consensus.blockNum, msg_pb.MessageType_ANNOUNCE, []nodeconfig.GroupID{
consensus.BlockNum(), msg_pb.MessageType_ANNOUNCE, []nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)),
}, p2p.ConstructMessage(msgToSend)); err != nil {
consensus.getLogger().Warn().
@ -99,7 +99,7 @@ func (consensus *Consensus) announce(block *types.Block) {
func (consensus *Consensus) onPrepare(recvMsg *FBFTMessage) {
// TODO(audit): make FBFT lookup using map instead of looping through all items.
if !consensus.FBFTLog.HasMatchingViewAnnounce(
consensus.blockNum, consensus.GetCurBlockViewID(), recvMsg.BlockHash,
consensus.BlockNum(), consensus.GetCurBlockViewID(), recvMsg.BlockHash,
) {
consensus.getLogger().Debug().
Uint64("MsgViewID", recvMsg.ViewID).

@ -42,7 +42,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error {
if err := rlp.DecodeBytes(consensus.block, &blockObj); err != nil {
consensus.getLogger().Warn().
Err(err).
Uint64("BlockNum", consensus.blockNum).
Uint64("BlockNum", consensus.BlockNum()).
Msg("[didReachPrepareQuorum] Unparseable block data")
return err
}
@ -69,7 +69,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error {
}
}
if err := consensus.msgSender.SendWithRetry(
consensus.blockNum,
consensus.BlockNum(),
msg_pb.MessageType_PREPARED, []nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID)),
},
@ -79,7 +79,7 @@ func (consensus *Consensus) didReachPrepareQuorum() error {
} else {
consensus.getLogger().Info().
Hex("blockHash", consensus.blockHash[:]).
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Msg("[OnPrepare] Sent Prepared Message!!")
}
consensus.msgSender.StopRetry(msg_pb.MessageType_ANNOUNCE)

@ -175,7 +175,7 @@ func (consensus *Consensus) sendCommitMessages(blockObj *types.Block) {
consensus.getLogger().Warn().Err(err).Msg("[sendCommitMessages] Cannot send commit message!!")
} else {
consensus.getLogger().Info().
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Hex("blockHash", consensus.blockHash[:]).
Msg("[sendCommitMessages] Sent Commit Message!!")
}
@ -192,14 +192,14 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) {
Uint64("MsgViewID", recvMsg.ViewID).
Msg("[OnPrepared] Received prepared message")
if recvMsg.BlockNum < consensus.blockNum {
if recvMsg.BlockNum < consensus.BlockNum() {
consensus.getLogger().Info().Uint64("MsgBlockNum", recvMsg.BlockNum).
Msg("Wrong BlockNum Received, ignoring!")
return
}
if recvMsg.BlockNum > consensus.blockNum {
if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Warn().
Uint64("myBlockNum", consensus.blockNum).
Uint64("myBlockNum", consensus.BlockNum()).
Uint64("MsgBlockNum", recvMsg.BlockNum).
Hex("myBlockHash", consensus.blockHash[:]).
Hex("MsgBlockHash", recvMsg.BlockHash[:]).
@ -245,10 +245,10 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) {
}
return
}
if recvMsg.BlockNum > consensus.blockNum {
if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Info().
Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("blockNum", consensus.blockNum).
Uint64("blockNum", consensus.BlockNum()).
Msg("[OnPrepared] Future Block Received, ignoring!!")
return
}
@ -274,12 +274,12 @@ func (consensus *Consensus) onPrepared(recvMsg *FBFTMessage) {
if blockObj == nil {
return
}
curBlockNum := consensus.blockNum
curBlockNum := consensus.BlockNum()
for _, committedMsg := range consensus.FBFTLog.GetNotVerifiedCommittedMessages(blockObj.NumberU64(), blockObj.Header().ViewID().Uint64(), blockObj.Hash()) {
if committedMsg != nil {
consensus.onCommitted(committedMsg)
}
if curBlockNum < consensus.blockNum {
if curBlockNum < consensus.BlockNum() {
consensus.getLogger().Info().Msg("[OnPrepared] Successfully caught up with committed message")
break
}
@ -297,16 +297,16 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) {
Msg("[OnCommitted] Received committed message")
// Ok to receive committed from last block since it could have more signatures
if recvMsg.BlockNum < consensus.blockNum-1 {
if recvMsg.BlockNum < consensus.BlockNum()-1 {
consensus.getLogger().Info().
Uint64("MsgBlockNum", recvMsg.BlockNum).
Msg("Wrong BlockNum Received, ignoring!")
return
}
if recvMsg.BlockNum > consensus.blockNum {
if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Info().
Uint64("myBlockNum", consensus.blockNum).
Uint64("myBlockNum", consensus.BlockNum()).
Uint64("MsgBlockNum", recvMsg.BlockNum).
Hex("myBlockHash", consensus.blockHash[:]).
Hex("MsgBlockHash", recvMsg.BlockHash[:]).
@ -372,12 +372,12 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) {
}
}
initBn := consensus.blockNum
initBn := consensus.BlockNum()
consensus.tryCatchup()
if recvMsg.BlockNum > consensus.blockNum {
if recvMsg.BlockNum > consensus.BlockNum() {
consensus.getLogger().Info().
Uint64("myBlockNum", consensus.blockNum).
Uint64("myBlockNum", consensus.BlockNum()).
Uint64("MsgBlockNum", recvMsg.BlockNum).
Hex("myBlockHash", consensus.blockHash[:]).
Hex("MsgBlockHash", recvMsg.BlockHash[:]).
@ -395,7 +395,7 @@ func (consensus *Consensus) onCommitted(recvMsg *FBFTMessage) {
consensus.getLogger().Debug().Msg("[OnCommitted] stop bootstrap timer only once")
}
if initBn < consensus.blockNum {
if initBn < consensus.BlockNum() {
consensus.getLogger().Info().Msg("[OnCommitted] Start consensus timer (new block added)")
consensus.consensusTimeout[timeoutConsensus].Start()
}

@ -264,7 +264,9 @@ func (consensus *Consensus) startViewChange() {
// aganist the consensus.LeaderPubKey variable.
// Ideally, we shall use another variable to keep track of the
// leader pubkey in viewchange mode
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = consensus.getNextLeaderKey(nextViewID)
consensus.pubKeyLock.Unlock()
consensus.getLogger().Warn().
Uint64("nextViewID", nextViewID).
@ -285,7 +287,7 @@ func (consensus *Consensus) startViewChange() {
if err := consensus.vc.InitPayload(
consensus.FBFTLog,
nextViewID,
consensus.blockNum,
consensus.BlockNum(),
consensus.priKey,
members); err != nil {
consensus.getLogger().Error().Err(err).Msg("[startViewChange] Init Payload Error")
@ -299,7 +301,7 @@ func (consensus *Consensus) startViewChange() {
}
msgToSend := consensus.constructViewChangeMessage(&key)
if err := consensus.msgSender.SendWithRetry(
consensus.blockNum,
consensus.BlockNum(),
msg_pb.MessageType_VIEWCHANGE,
[]nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))},
@ -325,7 +327,7 @@ func (consensus *Consensus) startNewView(viewID uint64, newLeaderPriKey *bls.Pri
}
if err := consensus.msgSender.SendWithRetry(
consensus.blockNum,
consensus.BlockNum(),
msg_pb.MessageType_NEWVIEW,
[]nodeconfig.GroupID{
nodeconfig.NewGroupIDByShardID(nodeconfig.ShardID(consensus.ShardID))},
@ -471,10 +473,10 @@ func (consensus *Consensus) onNewView(recvMsg *FBFTMessage) {
Msg("[onNewView] Received NewView Message")
// change view and leaderKey to keep in sync with network
if consensus.blockNum != recvMsg.BlockNum {
if consensus.BlockNum() != recvMsg.BlockNum {
consensus.getLogger().Warn().
Uint64("MsgBlockNum", recvMsg.BlockNum).
Uint64("myBlockNum", consensus.blockNum).
Uint64("myBlockNum", consensus.BlockNum()).
Msg("[onNewView] Invalid block number")
return
}
@ -551,7 +553,9 @@ func (consensus *Consensus) onNewView(recvMsg *FBFTMessage) {
// newView message verified success, override my state
consensus.SetViewIDs(recvMsg.ViewID)
consensus.pubKeyLock.Lock()
consensus.LeaderPubKey = senderKey
consensus.pubKeyLock.Unlock()
consensus.ResetViewChangeState()
consensus.msgSender.StopRetry(msg_pb.MessageType_VIEWCHANGE)

@ -24,7 +24,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra
Request: &msg_pb.Message_Viewchange{
Viewchange: &msg_pb.ViewChangeRequest{
ViewId: consensus.GetViewChangingID(),
BlockNum: consensus.blockNum,
BlockNum: consensus.BlockNum(),
ShardId: consensus.ShardID,
SenderPubkey: priKey.Pub.Bytes[:],
LeaderPubkey: consensus.LeaderPubKey.Bytes[:],
@ -33,7 +33,7 @@ func (consensus *Consensus) constructViewChangeMessage(priKey *bls.PrivateKeyWra
}
preparedMsgs := consensus.FBFTLog.GetMessagesByTypeSeq(
msg_pb.MessageType_PREPARED, consensus.blockNum,
msg_pb.MessageType_PREPARED, consensus.BlockNum(),
)
preparedMsg := consensus.FBFTLog.FindMessageByMaxViewID(preparedMsgs)
@ -107,7 +107,7 @@ func (consensus *Consensus) constructNewViewMessage(viewID uint64, priKey *bls.P
Request: &msg_pb.Message_Viewchange{
Viewchange: &msg_pb.ViewChangeRequest{
ViewId: viewID,
BlockNum: consensus.blockNum,
BlockNum: consensus.BlockNum(),
ShardId: consensus.ShardID,
SenderPubkey: priKey.Pub.Bytes[:],
},

@ -43,7 +43,7 @@ func TestPhaseSwitching(t *testing.T) {
_, _, consensus, _, err := GenerateConsensusForTesting()
assert.NoError(t, err)
assert.Equal(t, FBFTAnnounce, consensus.phase) // It's a new consensus, we should be at the FBFTAnnounce phase
assert.Equal(t, FBFTAnnounce, consensus.phase.Get()) // It's a new consensus, we should be at the FBFTAnnounce phase.
switches := []phaseSwitch{
{start: FBFTAnnounce, end: FBFTPrepare},
@ -73,10 +73,10 @@ func TestPhaseSwitching(t *testing.T) {
func testPhaseGroupSwitching(t *testing.T, consensus *Consensus, phases []FBFTPhase, startPhase FBFTPhase, desiredPhase FBFTPhase) {
for range phases {
consensus.switchPhase("test", desiredPhase)
assert.Equal(t, desiredPhase, consensus.phase)
assert.Equal(t, desiredPhase, consensus.phase.Get())
}
assert.Equal(t, desiredPhase, consensus.phase)
assert.Equal(t, desiredPhase, consensus.phase.Get())
return
}

@ -23,6 +23,7 @@ import (
"math/big"
"math/rand"
"os"
"sync/atomic"
"testing"
"time"
@ -69,9 +70,13 @@ type testBlockChain struct {
chainHeadFeed *event.Feed
}
func (bc *testBlockChain) SetGasLimit(value uint64) {
atomic.StoreUint64(&bc.gasLimit, value)
}
func (bc *testBlockChain) CurrentBlock() *types.Block {
return types.NewBlock(blockfactory.NewTestHeader().With().
GasLimit(bc.gasLimit).
GasLimit(atomic.LoadUint64(&bc.gasLimit)).
Header(), nil, nil, nil, nil, nil)
}
@ -162,12 +167,14 @@ func createBlockChain() *BlockChainImpl {
return blockchain
}
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
func setupTxPool(chain blockChain) (*TxPool, *ecdsa.PrivateKey) {
if chain == nil {
statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()))
blockchain := &testBlockChain{statedb, 1e18, new(event.Feed)}
chain = &testBlockChain{statedb, 1e18, new(event.Feed)}
}
key, _ := crypto.GenerateKey()
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain, dummyErrorSink)
pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, chain, dummyErrorSink)
return pool, key
}
@ -282,7 +289,7 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) {
func TestInvalidTransactions(t *testing.T) {
t.Parallel()
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
tx := transaction(0, 0, 100, key)
@ -324,8 +331,7 @@ func TestInvalidTransactions(t *testing.T) {
func TestErrorSink(t *testing.T) {
t.Parallel()
pool, key := setupTxPool()
pool.chain = createBlockChain()
pool, key := setupTxPool(createBlockChain())
defer pool.Stop()
testTxErrorSink := types.NewTransactionErrorSink()
@ -396,8 +402,7 @@ func TestErrorSink(t *testing.T) {
func TestCreateValidatorTransaction(t *testing.T) {
t.Parallel()
pool, _ := setupTxPool()
pool.chain = createBlockChain()
pool, _ := setupTxPool(createBlockChain())
defer pool.Stop()
fromKey, _ := crypto.GenerateKey()
@ -422,8 +427,7 @@ func TestCreateValidatorTransaction(t *testing.T) {
func TestMixedTransactions(t *testing.T) {
t.Parallel()
pool, _ := setupTxPool()
pool.chain = createBlockChain()
pool, _ := setupTxPool(createBlockChain())
defer pool.Stop()
fromKey, _ := crypto.GenerateKey()
@ -457,7 +461,7 @@ func TestBlacklistedTransactions(t *testing.T) {
// DO NOT parallelize, test will add accounts to tx pool config.
// Create the pool
pool, _ := setupTxPool()
pool, _ := setupTxPool(nil)
defer pool.Stop()
// Create testing keys
@ -501,7 +505,7 @@ func TestBlacklistedTransactions(t *testing.T) {
func TestTransactionQueue(t *testing.T) {
t.Parallel()
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
tx := transaction(0, 0, 100, key)
@ -528,7 +532,7 @@ func TestTransactionQueue(t *testing.T) {
t.Error("expected transaction queue to be empty. is", len(pool.queue))
}
pool, key = setupTxPool()
pool, key = setupTxPool(nil)
defer pool.Stop()
tx1 := transaction(0, 0, 100, key)
@ -555,7 +559,7 @@ func TestTransactionQueue(t *testing.T) {
func TestTransactionNegativeValue(t *testing.T) {
t.Parallel()
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
tx, _ := types.SignTx(
@ -569,9 +573,10 @@ func TestTransactionNegativeValue(t *testing.T) {
}
func TestTransactionChainFork(t *testing.T) {
t.Skip("This test doesn't work with race detector")
t.Parallel()
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey)
@ -600,18 +605,13 @@ func TestTransactionChainFork(t *testing.T) {
func TestTransactionDoubleNonce(t *testing.T) {
t.Parallel()
pool, key := setupTxPool()
defer pool.Stop()
key, _ := crypto.GenerateKey()
addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() {
statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()))
statedb.AddBalance(addr, big.NewInt(1000000000000000000))
pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)}
pool, _ := setupTxPool(&testBlockChain{statedb, 1000000, new(event.Feed)})
defer pool.Stop()
pool.lockedReset(nil, nil)
}
resetState()
signer := types.HomesteadSigner{}
tx1, _ := types.SignTx(
@ -656,7 +656,7 @@ func TestTransactionDoubleNonce(t *testing.T) {
func TestTransactionMissingNonce(t *testing.T) {
t.Parallel()
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey)
@ -680,7 +680,7 @@ func TestTransactionNonceRecovery(t *testing.T) {
t.Parallel()
const n = 10
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
addr := crypto.PubkeyToAddress(key.PublicKey)
@ -706,7 +706,7 @@ func TestTransactionDropping(t *testing.T) {
t.Parallel()
// Create a test account and fund it
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key))
@ -774,7 +774,7 @@ func TestTransactionDropping(t *testing.T) {
t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4)
}
// Reduce the block gas limit, check that invalidated transactions are dropped
pool.chain.(*testBlockChain).gasLimit = 100
pool.chain.(*testBlockChain).SetGasLimit(100)
pool.lockedReset(nil, nil)
if _, ok := pool.pending[account].txs.items[tx0.Nonce()]; !ok {
@ -918,7 +918,7 @@ func TestTransactionQueueAccountLimiting(t *testing.T) {
t.Parallel()
// Create a test account and fund it
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1119,7 +1119,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
t.Parallel()
// Add a batch of transactions to a pool one by one
pool1, key1 := setupTxPool()
pool1, key1 := setupTxPool(nil)
defer pool1.Stop()
account1, _ := deriveSender(transaction(0, 0, 0, key1))
@ -1131,7 +1131,7 @@ func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
}
}
// Add a batch of transactions to a pool in one big batch
pool2, key2 := setupTxPool()
pool2, key2 := setupTxPool(nil)
defer pool2.Stop()
account2, _ := deriveSender(transaction(0, 0, 0, key2))
@ -1523,7 +1523,7 @@ func BenchmarkPendingDemotion10000(b *testing.B) { benchmarkPendingDemotion(b, 1
func benchmarkPendingDemotion(b *testing.B, size int) {
// Add a batch of transactions to a pool one by one
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1548,7 +1548,7 @@ func BenchmarkFuturePromotion10000(b *testing.B) { benchmarkFuturePromotion(b, 1
func benchmarkFuturePromotion(b *testing.B, size int) {
// Add a batch of transactions to a pool one by one
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1568,7 +1568,7 @@ func benchmarkFuturePromotion(b *testing.B, size int) {
// Benchmarks the speed of iterative transaction insertion.
func BenchmarkPoolInsert(b *testing.B) {
// Generate a batch of transactions to enqueue into the pool
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key))
@ -1592,7 +1592,7 @@ func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 1
func benchmarkPoolBatchInsert(b *testing.B, size int) {
// Generate a batch of transactions to enqueue into the pool
pool, key := setupTxPool()
pool, key := setupTxPool(nil)
defer pool.Stop()
account, _ := deriveSender(transaction(0, 0, 0, key))

@ -227,6 +227,7 @@ type Block struct {
ReceivedAt time.Time
ReceivedFrom interface{}
commitLock sync.Mutex
// Commit Signatures/Bitmap
commitSigAndBitmap []byte
}
@ -264,11 +265,15 @@ func (b *Block) SetCurrentCommitSig(sigAndBitmap []byte) {
Int("dstLen", len(b.header.LastCommitSignature())).
Msg("SetCurrentCommitSig: sig size mismatch")
}
b.commitLock.Lock()
b.commitSigAndBitmap = sigAndBitmap
b.commitLock.Unlock()
}
// GetCurrentCommitSig get the commit group signature that signed on this block.
func (b *Block) GetCurrentCommitSig() []byte {
b.commitLock.Lock()
defer b.commitLock.Unlock()
return b.commitSigAndBitmap
}

@ -0,0 +1,30 @@
package keylocker
import "sync"
type KeyLocker struct {
m sync.Map
}
func New() *KeyLocker {
return &KeyLocker{}
}
func (a *KeyLocker) Lock(key interface{}, f func() (interface{}, error)) (interface{}, error) {
mu := &sync.Mutex{}
for {
actual, _ := a.m.LoadOrStore(key, mu)
mu2 := actual.(*sync.Mutex)
mu.Lock()
if mu2 != mu {
// acquired someone else lock.
mu.Unlock()
continue
}
rs, err := f()
mu.Unlock()
a.m.Delete(key)
return rs, err
}
}

@ -0,0 +1,17 @@
package keylocker
import "testing"
func TestKeyLocker_Sequential(t *testing.T) {
n := New()
key := 1
v := make(chan struct{}, 1)
n.Lock(key, func() (interface{}, error) {
v <- struct{}{}
return nil, nil
})
<-v // we know that goroutine really executed
n.Lock(key, func() (interface{}, error) {
return nil, nil
})
}

@ -1,35 +0,0 @@
package utils
import (
"sync"
"testing"
"time"
)
var NumThreads = 20
func TestSingleton(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < NumThreads; i++ {
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(time.Millisecond)
}()
}
wg.Wait()
n := 100
for i := 0; i < NumThreads; i++ {
wg.Add(1)
go func() {
defer wg.Done()
n++
time.Sleep(time.Millisecond)
}()
}
wg.Wait()
}

@ -1,6 +1,7 @@
package utils
import (
"sync"
"time"
)
@ -19,6 +20,7 @@ type Timeout struct {
state TimeoutState
d time.Duration
start time.Time
mu sync.Mutex
}
// NewTimeout creates a new timeout class
@ -29,18 +31,24 @@ func NewTimeout(d time.Duration) *Timeout {
// Start starts the timeout clock
func (timeout *Timeout) Start() {
timeout.mu.Lock()
timeout.state = Active
timeout.start = time.Now()
timeout.mu.Unlock()
}
// Stop stops the timeout clock
func (timeout *Timeout) Stop() {
timeout.mu.Lock()
timeout.state = Inactive
timeout.start = time.Now()
timeout.mu.Unlock()
}
// CheckExpire checks whether the timeout is reached/expired
func (timeout *Timeout) CheckExpire() bool {
timeout.mu.Lock()
defer timeout.mu.Unlock()
if timeout.state == Active && time.Since(timeout.start) > timeout.d {
timeout.state = Expired
}
@ -52,17 +60,23 @@ func (timeout *Timeout) CheckExpire() bool {
// Duration returns the duration period of timeout
func (timeout *Timeout) Duration() time.Duration {
timeout.mu.Lock()
defer timeout.mu.Unlock()
return timeout.d
}
// SetDuration set new duration for the timer
func (timeout *Timeout) SetDuration(nd time.Duration) {
timeout.mu.Lock()
timeout.d = nd
timeout.mu.Unlock()
}
// IsActive checks whether timeout clock is active;
// A timeout is active means it's not stopped caused by stop
// and also not expired with time elapses longer than duration from start
func (timeout *Timeout) IsActive() bool {
timeout.mu.Lock()
defer timeout.mu.Unlock()
return timeout.state == Active
}

@ -125,7 +125,7 @@ func (node *Node) GetConsensusCurViewID() uint64 {
// GetConsensusBlockNum returns the current block number of the consensus
func (node *Node) GetConsensusBlockNum() uint64 {
return node.Consensus.GetBlockNum()
return node.Consensus.BlockNum()
}
// GetConsensusInternal returns consensus internal data

@ -134,7 +134,7 @@ func (node *Node) ProposeNewBlock(commitSigs chan []byte) (*types.Block, error)
header := node.Worker.GetCurrentHeader()
// Update worker's current header and
// state data in preparation to propose/process new transactions
leaderKey := node.Consensus.LeaderPubKey
leaderKey := node.Consensus.GetLeaderPubKey()
var (
coinbase = node.GetAddressForBLSKey(leaderKey.Object, header.Epoch())
beneficiary = coinbase

@ -190,8 +190,8 @@ func NewHost(cfg HostConfig) (Host, error) {
priKey: key,
discovery: disc,
security: security,
onConnections: []ConnectCallback{},
onDisconnects: []DisconnectCallback{},
onConnections: ConnectCallbacks{},
onDisconnects: DisconnectCallbacks{},
logger: &subLogger,
ctx: ctx,
cancel: cancel,
@ -218,8 +218,8 @@ type HostV2 struct {
security security.Security
logger *zerolog.Logger
blocklist libp2p_pubsub.Blacklist
onConnections []ConnectCallback
onDisconnects []DisconnectCallback
onConnections ConnectCallbacks
onDisconnects DisconnectCallbacks
ctx context.Context
cancel func()
}
@ -433,7 +433,7 @@ func (host *HostV2) ListenClose(net libp2p_network.Network, addr ma.Multiaddr) {
func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Conn) {
host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer connected")
for _, function := range host.onConnections {
for _, function := range host.onConnections.GetAll() {
if err := function(net, conn); err != nil {
host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer connected callback")
}
@ -444,7 +444,7 @@ func (host *HostV2) Connected(net libp2p_network.Network, conn libp2p_network.Co
func (host *HostV2) Disconnected(net libp2p_network.Network, conn libp2p_network.Conn) {
host.logger.Info().Interface("node", conn.RemotePeer()).Msg("peer disconnected")
for _, function := range host.onDisconnects {
for _, function := range host.onDisconnects.GetAll() {
if err := function(conn); err != nil {
host.logger.Error().Err(err).Interface("node", conn.RemotePeer()).Msg("failed on peer disconnected callback")
}
@ -462,11 +462,11 @@ func (host *HostV2) ClosedStream(net libp2p_network.Network, stream libp2p_netwo
}
func (host *HostV2) SetConnectCallback(callback ConnectCallback) {
host.onConnections = append(host.onConnections, callback)
host.onConnections.Add(callback)
}
func (host *HostV2) SetDisconnectCallback(callback DisconnectCallback) {
host.onDisconnects = append(host.onDisconnects, callback)
host.onDisconnects.Add(callback)
}
// NamedTopic represents pubsub topic

@ -6,6 +6,7 @@ import (
"github.com/harmony-one/harmony/consensus/engine"
"github.com/harmony-one/harmony/core/types"
shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding"
"github.com/harmony-one/harmony/internal/utils/keylocker"
"github.com/pkg/errors"
)
@ -20,12 +21,14 @@ type chainHelper interface {
type chainHelperImpl struct {
chain engine.ChainReader
schedule shardingconfig.Schedule
keyLocker *keylocker.KeyLocker
}
func newChainHelper(chain engine.ChainReader, schedule shardingconfig.Schedule) *chainHelperImpl {
return &chainHelperImpl{
chain: chain,
schedule: schedule,
keyLocker: keylocker.New(),
}
}
@ -89,9 +92,8 @@ func (ch *chainHelperImpl) getBlocksByHashes(hs []common.Hash) ([]*types.Block,
return blocks, nil
}
var errBlockNotFound = errors.New("block not found")
func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types.Block, error) {
rs, err := ch.keyLocker.Lock(header.Number().Uint64(), func() (interface{}, error) {
b := ch.chain.GetBlock(header.Hash(), header.Number().Uint64())
if b == nil {
return nil, nil
@ -102,6 +104,13 @@ func (ch *chainHelperImpl) getBlockWithSigByHeader(header *block.Header) (*types
}
b.SetCurrentCommitSig(commitSig)
return b, nil
})
if err != nil {
return nil, err
}
return rs.(*types.Block), nil
}
func (ch *chainHelperImpl) getBlockSigAndBitmap(header *block.Header) ([]byte, error) {

@ -2,6 +2,7 @@ package sync
import (
"context"
"sync"
"testing"
"time"
@ -50,11 +51,12 @@ func TestProtocol_advertiseLoop(t *testing.T) {
time.Sleep(150 * time.Millisecond)
close(p.closeC)
if len(disc.advCnt) != len(p.supportedVersions()) {
t.Errorf("unexpected advertise topic count: %v / %v", len(disc.advCnt),
advCnt := disc.Extract()
if len(advCnt) != len(p.supportedVersions()) {
t.Errorf("unexpected advertise topic count: %v / %v", len(advCnt),
len(p.supportedVersions()))
}
for _, cnt := range disc.advCnt {
for _, cnt := range advCnt {
if cnt < 1 {
t.Errorf("unexpected discovery count: %v", cnt)
}
@ -64,6 +66,7 @@ func TestProtocol_advertiseLoop(t *testing.T) {
type testDiscovery struct {
advCnt map[string]int
sleep time.Duration
mu sync.Mutex
}
func newTestDiscovery(discInterval time.Duration) *testDiscovery {
@ -82,10 +85,20 @@ func (disc *testDiscovery) Close() error {
}
func (disc *testDiscovery) Advertise(ctx context.Context, ns string) (time.Duration, error) {
disc.mu.Lock()
defer disc.mu.Unlock()
disc.advCnt[ns]++
return disc.sleep, nil
}
func (disc *testDiscovery) Extract() map[string]int {
disc.mu.Lock()
defer disc.mu.Unlock()
var out map[string]int
out, disc.advCnt = disc.advCnt, make(map[string]int)
return out
}
func (disc *testDiscovery) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) {
return nil, nil
}

@ -0,0 +1,41 @@
package p2p
import "sync"
type ConnectCallbacks struct {
cbs []ConnectCallback
mu sync.RWMutex
}
func (a *ConnectCallbacks) Add(cb ConnectCallback) {
a.mu.Lock()
defer a.mu.Unlock()
a.cbs = append(a.cbs, cb)
}
func (a *ConnectCallbacks) GetAll() []ConnectCallback {
a.mu.RLock()
defer a.mu.RUnlock()
out := make([]ConnectCallback, len(a.cbs))
copy(out, a.cbs)
return out
}
type DisconnectCallbacks struct {
cbs []DisconnectCallback
mu sync.RWMutex
}
func (a *DisconnectCallbacks) Add(cb DisconnectCallback) {
a.mu.Lock()
defer a.mu.Unlock()
a.cbs = append(a.cbs, cb)
}
func (a *DisconnectCallbacks) GetAll() []DisconnectCallback {
a.mu.RLock()
defer a.mu.RUnlock()
out := make([]DisconnectCallback, len(a.cbs))
copy(out, a.cbs)
return out
}

@ -0,0 +1,37 @@
package p2p
import (
"reflect"
"testing"
libp2p_network "github.com/libp2p/go-libp2p-core/network"
"github.com/stretchr/testify/require"
)
func TestConnectCallbacks(t *testing.T) {
cbs := ConnectCallbacks{}
fn := func(net libp2p_network.Network, conn libp2p_network.Conn) error {
return nil
}
require.Equal(t, 0, len(cbs.GetAll()))
cbs.Add(fn)
require.Equal(t, 1, len(cbs.GetAll()))
require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(cbs.GetAll()[0]).Pointer())
}
func TestDisConnectCallbacks(t *testing.T) {
cbs := DisconnectCallbacks{}
fn := func(conn libp2p_network.Conn) error {
return nil
}
require.Equal(t, 0, len(cbs.GetAll()))
cbs.Add(fn)
require.Equal(t, 1, len(cbs.GetAll()))
require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(cbs.GetAll()[0]).Pointer())
}

@ -82,7 +82,7 @@ fi
echo "Running go test..."
# Fix https://github.com/golang/go/issues/44129#issuecomment-788351567
go get -t ./...
if go test -v -count=1 ./...
if go test -v -count=1 -vet=all -race ./...
then
echo "go test succeeded."
else

Loading…
Cancel
Save