make resharding logic work

pull/575/head
Rongjian Lan 6 years ago
parent 0bc22a5784
commit ccb2b7ab2e
  1. 2
      api/service/staking/service.go
  2. 2
      core/blockchain.go
  3. 73
      core/resharding.go
  4. 52
      core/resharding_test.go
  5. 11
      core/types/shard_state.go
  6. 44
      core/types/shard_state_test.go
  7. 11
      node/node_handler.go
  8. 2
      node/service_setup.go

@ -87,7 +87,7 @@ func (s *Service) Run() {
if s.IsStaked() {
return
}
s.DoService()
//s.DoService()
return
case <-s.stopChan:
return

@ -70,7 +70,7 @@ const (
// BlocksPerEpoch is the number of blocks in one epoch
// currently set to small number for testing
// in future, this need to be adjusted dynamically instead of constant
BlocksPerEpoch = 5
BlocksPerEpoch = 10
// BlockChainVersion ensures that an incompatible database forces a resync from scratch.
BlockChainVersion = 3

@ -3,7 +3,6 @@ package core
import (
"encoding/binary"
"encoding/hex"
"math"
"math/rand"
"sort"
@ -19,12 +18,16 @@ import (
const (
// InitialSeed is the initial random seed, a magic number to answer everything, remove later
InitialSeed uint32 = 42
// GenesisEpoch is the number of the first genesis epoch.
// GenesisEpoch is the number of the genesis epoch.
GenesisEpoch = 0
// FirstEpoch is the number of the first epoch.
FirstEpoch = 1
// GenesisShardNum is the number of shard at genesis
GenesisShardNum = 3
GenesisShardNum = 4
// GenesisShardSize is the size of each shard at genesis
GenesisShardSize = 10
// CuckooRate is the percentage of nodes getting reshuffled in the second step of cuckoo resharding.
CuckooRate = 0.1
)
// ShardingState is data structure hold the sharding state
@ -67,9 +70,12 @@ func (ss *ShardingState) cuckooResharding(percent float64) {
}
numKicked := int(percent * float64(len(ss.shardState[i].NodeList)))
if numKicked == 0 {
numKicked++
numKicked++ // At least kick one node out
}
length := len(ss.shardState[i].NodeList)
if length-numKicked <= 0 {
continue // Never empty a shard
}
tmp := ss.shardState[i].NodeList[length-numKicked:]
kickedNodes = append(kickedNodes, tmp...)
ss.shardState[i].NodeList = ss.shardState[i].NodeList[:length-numKicked]
@ -138,33 +144,40 @@ func GetShardingStateFromBlockChain(bc *BlockChain, epoch uint64) *ShardingState
}
// CalculateNewShardState get sharding state from previous epoch and calculate sharding state for new epoch
// TODO: currently, we just mock everything
func CalculateNewShardState(bc *BlockChain, epoch uint64, stakeInfo *map[common.Address]*structs.StakeInfo) types.ShardState {
if epoch == GenesisEpoch {
return GetInitShardState()
}
ss := GetShardingStateFromBlockChain(bc, epoch-1)
if epoch == FirstEpoch {
newNodes := []types.NodeID{}
for addr, stakeInfo := range *stakeInfo {
newNodes = append(newNodes, types.NodeID{addr.Hex(), hex.EncodeToString(stakeInfo.BlsAddress[:])})
}
rand.Seed(int64(ss.rnd))
Shuffle(newNodes)
utils.GetLogInstance().Info("[resharding] New Nodes", "data", newNodes)
for i, nid := range newNodes {
id := i%(GenesisShardNum-1) + 1 // assign the node to one of the empty shard
ss.shardState[id].NodeList = append(ss.shardState[id].NodeList, nid)
}
utils.GetLogInstance().Info("State", "data", ss)
return ss.shardState
}
newNodeList := ss.UpdateShardingState(stakeInfo)
percent := ss.calculateKickoutRate(newNodeList)
utils.GetLogInstance().Info("Kickout Percentage", "percentage", percent)
ss.Reshard(newNodeList, percent)
utils.GetLogInstance().Info("Cuckoo Rate", "percentage", CuckooRate)
ss.Reshard(newNodeList, CuckooRate)
return ss.shardState
}
// UpdateShardingState remove the unstaked nodes and returns the newly staked node Ids.
func (ss *ShardingState) UpdateShardingState(stakeInfo *map[common.Address]*structs.StakeInfo) []types.NodeID {
oldAddresses := make(map[common.Address]bool)
oldAddresses := make(map[string]bool) // map of bls addresses
for _, shard := range ss.shardState {
newNodeList := shard.NodeList[:0]
for _, nodeID := range shard.NodeList {
addr := common.Address{}
addrBytes, err := hex.DecodeString(string(nodeID))
if err != nil {
utils.GetLogInstance().Error("Failed to decode address hex")
}
addr.SetBytes(addrBytes)
oldAddresses[addr] = true
_, ok := (*stakeInfo)[addr]
oldAddresses[nodeID.BlsAddress] = true
_, ok := (*stakeInfo)[common.HexToAddress(nodeID.EcdsaAddress)]
if ok {
newNodeList = append(newNodeList, nodeID)
} else {
@ -175,28 +188,15 @@ func (ss *ShardingState) UpdateShardingState(stakeInfo *map[common.Address]*stru
}
newAddresses := []types.NodeID{}
for addr := range *stakeInfo {
_, ok := oldAddresses[addr]
for addr, info := range *stakeInfo {
_, ok := oldAddresses[addr.Hex()]
if !ok {
newAddresses = append(newAddresses, types.NodeID(addr.Hex()))
newAddresses = append(newAddresses, types.NodeID{hex.EncodeToString(info.BlsAddress[:]), addr.Hex()})
}
}
return newAddresses
}
// calculateKickoutRate calculates the cuckoo rule kick out rate in order to make committee balanced
func (ss *ShardingState) calculateKickoutRate(newNodeList []types.NodeID) float64 {
numActiveCommittees := ss.numShards / 2
newNodesPerShard := math.Ceil(float64(len(newNodeList)) / float64(numActiveCommittees))
ss.sortCommitteeBySize()
L := len(ss.shardState[0].NodeList)
if L == 0 {
return 0.0
}
rate := newNodesPerShard / float64(L)
return math.Max(0.1, math.Min(rate, 1.0))
}
// GetInitShardState returns the initial shard state at genesis.
// TODO: make the deploy.sh config file in sync with genesis constants.
func GetInitShardState() types.ShardState {
@ -206,10 +206,11 @@ func GetInitShardState() types.ShardState {
if i == 0 {
for j := 0; j < GenesisShardSize; j++ {
priKey := bls.SecretKey{}
priKey.SetHexString(contract.InitialBeaconChainAccounts[j].Private)
priKey.SetHexString(contract.InitialBeaconChainBLSAccounts[j].Private)
addrBytes := priKey.GetPublicKey().GetAddress()
address := hex.EncodeToString(addrBytes[:])
com.NodeList = append(com.NodeList, types.NodeID(address))
blsAddr := hex.EncodeToString(addrBytes[:])
// TODO: directly read address for bls too
com.NodeList = append(com.NodeList, types.NodeID{blsAddr, contract.InitialBeaconChainAccounts[j].Address})
}
}
shardState = append(shardState, com)

@ -18,7 +18,7 @@ func fakeGetInitShardState(numberOfShards, numOfNodes int) types.ShardState {
com := types.Committee{ShardID: sid}
for j := 0; j < numOfNodes; j++ {
nid := strconv.Itoa(int(rand.Int63()))
com.NodeList = append(com.NodeList, types.NodeID(nid))
com.NodeList = append(com.NodeList, types.NodeID{nid, nid})
}
shardState = append(shardState, com)
}
@ -31,7 +31,7 @@ func fakeNewNodeList(seed int64) []types.NodeID {
nodeList := []types.NodeID{}
for i := 0; i < numNewNodes; i++ {
nid := strconv.Itoa(int(rand.Int63()))
nodeList = append(nodeList, types.NodeID(nid))
nodeList = append(nodeList, types.NodeID{nid, nid})
}
return nodeList
}
@ -43,16 +43,16 @@ func TestFakeNewNodeList(t *testing.T) {
func TestShuffle(t *testing.T) {
nodeList := []types.NodeID{
"node1",
"node2",
"node3",
"node4",
"node5",
"node6",
"node7",
"node8",
"node9",
"node10",
{"node1", "node1"},
{"node2", "node2"},
{"node3", "node3"},
{"node4", "node4"},
{"node5", "node5"},
{"node6", "node6"},
{"node7", "node7"},
{"node8", "node8"},
{"node9", "node9"},
{"node10", "node10"},
}
cpList := []types.NodeID{}
@ -83,18 +83,18 @@ func TestUpdateShardState(t *testing.T) {
shardState := fakeGetInitShardState(6, 10)
ss := &ShardingState{epoch: 1, rnd: 42, shardState: shardState, numShards: len(shardState)}
newNodeList := []types.NodeID{
"node1",
"node2",
"node3",
"node4",
"node5",
"node6",
{"node1", "node1"},
{"node2", "node2"},
{"node3", "node3"},
{"node4", "node4"},
{"node5", "node5"},
{"node6", "node6"},
}
ss.Reshard(newNodeList, 0.2)
assert.Equal(t, 6, ss.numShards)
for _, shard := range ss.shardState {
assert.Equal(t, string(shard.Leader), string(shard.NodeList[0]))
assert.Equal(t, shard.Leader.BlsAddress, shard.NodeList[0].BlsAddress)
}
}
@ -102,20 +102,12 @@ func TestAssignNewNodes(t *testing.T) {
shardState := fakeGetInitShardState(2, 2)
ss := &ShardingState{epoch: 1, rnd: 42, shardState: shardState, numShards: len(shardState)}
newNodes := []types.NodeID{
"node1",
"node2",
"node3",
{"node1", "node1"},
{"node2", "node2"},
{"node3", "node3"},
}
ss.assignNewNodes(newNodes)
assert.Equal(t, 2, ss.numShards)
assert.Equal(t, 5, len(ss.shardState[0].NodeList))
}
func TestCalculateKickoutRate(t *testing.T) {
shardState := fakeGetInitShardState(6, 10)
ss := &ShardingState{epoch: 1, rnd: 42, shardState: shardState, numShards: len(shardState)}
newNodeList := fakeNewNodeList(42)
percent := ss.calculateKickoutRate(newNodeList)
assert.Equal(t, 0.2, percent)
}

@ -11,8 +11,11 @@ import (
// ShardState is the collection of all committees
type ShardState []Committee
// NodeID represents node id.
type NodeID string
// NodeID represents node id (BLS address).
type NodeID struct {
EcdsaAddress string
BlsAddress string
}
// Committee contains the active nodes in one shard
type Committee struct {
@ -55,10 +58,10 @@ func (ss ShardState) Hash() (h common.Hash) {
// CompareNodeID compares two nodes by their ID; used to sort node list
func CompareNodeID(n1 NodeID, n2 NodeID) int {
return strings.Compare(string(n1), string(n2))
return strings.Compare(n1.BlsAddress, n2.BlsAddress)
}
// Serialize serialize NodeID into bytes
func (n NodeID) Serialize() []byte {
return []byte(string(n))
return []byte(n.BlsAddress)
}

@ -7,14 +7,14 @@ import (
func TestGetHashFromNodeList(t *testing.T) {
l1 := []NodeID{
"node1",
"node2",
"node3",
{"node1", "node1"},
{"node2", "node2"},
{"node3", "node3"},
}
l2 := []NodeID{
"node2",
"node1",
"node3",
{"node2", "node2"},
{"node1", "node1"},
{"node3", "node3"},
}
h1 := GetHashFromNodeList(l1)
h2 := GetHashFromNodeList(l2)
@ -27,20 +27,20 @@ func TestGetHashFromNodeList(t *testing.T) {
func TestHash(t *testing.T) {
com1 := Committee{
ShardID: 22,
Leader: "node11",
Leader: NodeID{"node11", "node11"},
NodeList: []NodeID{
"node11",
"node22",
"node1",
{"node11", "node11"},
{"node22", "node22"},
{"node1", "node1"},
},
}
com2 := Committee{
ShardID: 2,
Leader: "node4",
Leader: NodeID{"node4", "node4"},
NodeList: []NodeID{
"node4",
"node5",
"node6",
{"node4", "node4"},
{"node5", "node5"},
{"node6", "node6"},
},
}
shardState1 := ShardState{com1, com2}
@ -48,20 +48,20 @@ func TestHash(t *testing.T) {
com3 := Committee{
ShardID: 2,
Leader: "node4",
Leader: NodeID{"node4", "node4"},
NodeList: []NodeID{
"node6",
"node5",
"node4",
{"node6", "node6"},
{"node5", "node5"},
{"node4", "node4"},
},
}
com4 := Committee{
ShardID: 22,
Leader: "node11",
Leader: NodeID{"node11", "node11"},
NodeList: []NodeID{
"node1",
"node11",
"node22",
{"node1", "node1"},
{"node11", "node11"},
{"node22", "node22"},
},
}

@ -286,18 +286,15 @@ func (node *Node) PostConsensusProcessing(newBlock *types.Block) {
}()
}
if node.NodeConfig.Role() == nodeconfig.BeaconLeader {
utils.GetLogInstance().Info("Updating staking list")
node.UpdateStakingList(node.QueryStakeInfo())
node.printStakingList()
}
utils.GetLogInstance().Info("Updating staking list")
node.UpdateStakingList(node.QueryStakeInfo())
node.printStakingList()
node.blockchain.StoreNewShardState(newBlock, &node.CurrentStakes)
}
// AddNewBlock is usedd to add new block into the blockchain.
func (node *Node) AddNewBlock(newBlock *types.Block) {
blockNum, err := node.blockchain.InsertChain([]*types.Block{newBlock})
node.blockchain.StoreNewShardState(newBlock, &node.CurrentStakes)
if err != nil {
utils.GetLogInstance().Debug("Error adding new block to blockchain", "blockNum", blockNum, "Error", err)
} else {

@ -77,7 +77,7 @@ func (node *Node) setupForBeaconValidator() {
func (node *Node) setupForNewNode() {
// TODO determine the role of new node, currently assume it is beacon node
nodeConfig, chanPeer := node.initNodeConfiguration(true, false)
nodeConfig, chanPeer := node.initNodeConfiguration(true, true)
// Register staking service.
node.serviceManager.RegisterService(service.Staking, staking.New(node.host, node.AccountKey, node.beaconChain, node.NodeConfig.ConsensusPubKey.GetAddress()))

Loading…
Cancel
Save