diff --git a/cmd/harmony/main.go b/cmd/harmony/main.go index c9c4ec1ea..a90faf628 100644 --- a/cmd/harmony/main.go +++ b/cmd/harmony/main.go @@ -453,20 +453,13 @@ func main() { if *nodeType == "validator" { setupInitialAccount() if *stakingFlag { - var blsPubKey shard.BlsPublicKey - pubKey := nodeconfig.GetDefaultConfig().ConsensusPubKey - if err := blsPubKey.FromLibBLSPublicKey(pubKey); err != nil { - _, _ = fmt.Fprint(os.Stderr, - "ERROR cannot convert libbls pubkey to internal form: %s", - err) + shardID, err := nodeconfig.GetDefaultConfig().ShardIDFromConsensusKey() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, + "ERROR cannot determine shard to join: %s", err) os.Exit(1) } - // Use the number of shards as of staking epoch. - chainConfig := nodeconfig.NetworkType(*networkType).ChainConfig() - stakingEpochShardConfig := shard.Schedule.InstanceForEpoch(chainConfig.StakingEpoch) - numShardsBig := big.NewInt(int64(stakingEpochShardConfig.NumShards())) - shardIDBig := new(big.Int).Mod(blsPubKey.Big(), numShardsBig) - initialAccount.ShardID = uint32(shardIDBig.Uint64()) + initialAccount.ShardID = shardID } } fmt.Printf("%s mode; node key %s -> shard %d\n", diff --git a/internal/configs/node/config.go b/internal/configs/node/config.go index 45441c9ea..75908ee04 100644 --- a/internal/configs/node/config.go +++ b/internal/configs/node/config.go @@ -6,13 +6,16 @@ package nodeconfig import ( "crypto/ecdsa" "fmt" + "math/big" "sync" "github.com/harmony-one/bls/ffi/go/bls" p2p_crypto "github.com/libp2p/go-libp2p-crypto" + "github.com/pkg/errors" shardingconfig "github.com/harmony-one/harmony/internal/configs/sharding" "github.com/harmony-one/harmony/internal/params" + "github.com/harmony-one/harmony/shard" ) // Role defines a role of a node. @@ -271,6 +274,21 @@ func SetShardingSchedule(schedule shardingconfig.Schedule) { } } +// ShardIDFromConsensusKey returns the shard ID statically determined from the +// consensus key. +func (conf *ConfigType) ShardIDFromConsensusKey() (uint32, error) { + var pubKey shard.BlsPublicKey + if err := pubKey.FromLibBLSPublicKey(conf.ConsensusPubKey); err != nil { + return 0, errors.Wrapf(err, + "cannot convert libbls public key %s to internal form", + conf.ConsensusPubKey.SerializeToHexStr()) + } + epoch := conf.networkType.ChainConfig().StakingEpoch + numShards := conf.shardingSchedule.InstanceForEpoch(epoch).NumShards() + shardID := new(big.Int).Mod(pubKey.Big(), big.NewInt(int64(numShards))) + return uint32(shardID.Uint64()), nil +} + // ChainConfig returns the chain configuration for the network type. func (t NetworkType) ChainConfig() params.ChainConfig { switch t {