Merge pull request #3219 from JackyWYX/refactor_bls_read
[cmd] Migrate BLS load logic to harmony binary.pull/3266/head
commit
5062c178a7
@ -0,0 +1,129 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"errors" |
||||
"flag" |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/harmony-one/harmony/internal/blsgen" |
||||
nodeconfig "github.com/harmony-one/harmony/internal/configs/node" |
||||
"github.com/harmony-one/harmony/multibls" |
||||
) |
||||
|
||||
var ( |
||||
blsKeyFile = flag.String("blskey_file", "", "The encrypted file of bls serialized private key by passphrase.") |
||||
blsFolder = flag.String("blsfolder", ".hmy/blskeys", "The folder that stores the bls keys and corresponding passphrases; e.g. <blskey>.key and <blskey>.pass; all bls keys mapped to same shard") |
||||
maxBLSKeysPerNode = flag.Int("max_bls_keys_per_node", 10, "Maximum number of bls keys allowed per node (default 4)") |
||||
|
||||
// TODO(jacky): rename it to a better name with cobra alias
|
||||
blsPass = flag.String("blspass", "default", "The source for bls passphrases. (default, no-prompt, prompt, file:$PASS_FILE, none)") |
||||
persistPass = flag.Bool("save-passphrase", false, "Whether the prompt passphrase is saved after prompt.") |
||||
awsConfigSource = flag.String("aws-config-source", "default", "The source for aws config. (default, prompt, file:$CONFIG_FILE, none)") |
||||
) |
||||
|
||||
var ( |
||||
multiBLSPriKey multibls.PrivateKeys |
||||
onceLoadBLSKey sync.Once |
||||
) |
||||
|
||||
// setupConsensusKeys load bls keys and set the keys to nodeConfig. Return the loaded public keys.
|
||||
func setupConsensusKeys(config *nodeconfig.ConfigType) multibls.PublicKeys { |
||||
onceLoadBLSKey.Do(func() { |
||||
var err error |
||||
multiBLSPriKey, err = loadBLSKeys() |
||||
if err != nil { |
||||
fmt.Fprintf(os.Stderr, "ERROR when loading bls key: %v\n", err) |
||||
os.Exit(100) |
||||
} |
||||
fmt.Printf("Successfully loaded %v BLS keys\n", len(multiBLSPriKey)) |
||||
}) |
||||
config.ConsensusPriKey = multiBLSPriKey |
||||
return multiBLSPriKey.GetPublicKeys() |
||||
} |
||||
|
||||
func loadBLSKeys() (multibls.PrivateKeys, error) { |
||||
config, err := parseBLSLoadingConfig() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
keys, err := blsgen.LoadKeys(config) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(keys) == 0 { |
||||
return nil, fmt.Errorf("0 bls keys loaded") |
||||
} |
||||
if len(keys) >= *maxBLSKeysPerNode { |
||||
return nil, fmt.Errorf("bls keys exceed maximum count %v", *maxBLSKeysPerNode) |
||||
} |
||||
return keys, err |
||||
} |
||||
|
||||
func parseBLSLoadingConfig() (blsgen.Config, error) { |
||||
var ( |
||||
config blsgen.Config |
||||
err error |
||||
) |
||||
if len(*blsKeyFile) != 0 { |
||||
config.MultiBlsKeys = strings.Split(*blsKeyFile, ",") |
||||
} |
||||
config.BlsDir = blsFolder |
||||
|
||||
config, err = parseBLSPass(config, *blsPass) |
||||
if err != nil { |
||||
return blsgen.Config{}, err |
||||
} |
||||
config, err = parseAwsConfigSrc(config, *awsConfigSource) |
||||
if err != nil { |
||||
return blsgen.Config{}, err |
||||
} |
||||
return config, nil |
||||
} |
||||
|
||||
func parseBLSPass(config blsgen.Config, src string) (blsgen.Config, error) { |
||||
methodArgs := strings.SplitN(src, ":", 2) |
||||
method := methodArgs[0] |
||||
|
||||
switch method { |
||||
case "default", "stdin": |
||||
config.PassSrcType = blsgen.PassSrcAuto |
||||
case "file": |
||||
config.PassSrcType = blsgen.PassSrcFile |
||||
if len(methodArgs) < 2 { |
||||
return blsgen.Config{}, errors.New("must specify passphrase file") |
||||
} |
||||
config.PassFile = &methodArgs[1] |
||||
case "no-prompt": |
||||
config.PassSrcType = blsgen.PassSrcFile |
||||
case "prompt": |
||||
config.PassSrcType = blsgen.PassSrcPrompt |
||||
config.PersistPassphrase = *persistPass |
||||
case "none": |
||||
config.PassSrcType = blsgen.PassSrcNil |
||||
} |
||||
config.PersistPassphrase = *persistPass |
||||
return config, nil |
||||
} |
||||
|
||||
func parseAwsConfigSrc(config blsgen.Config, src string) (blsgen.Config, error) { |
||||
methodArgs := strings.SplitN(src, ":", 2) |
||||
method := methodArgs[0] |
||||
switch method { |
||||
case "default": |
||||
config.AwsCfgSrcType = blsgen.AwsCfgSrcShared |
||||
case "file": |
||||
config.AwsCfgSrcType = blsgen.AwsCfgSrcFile |
||||
if len(methodArgs) < 2 { |
||||
return blsgen.Config{}, errors.New("must specify aws config file") |
||||
} |
||||
config.AwsConfigFile = &methodArgs[1] |
||||
case "prompt": |
||||
config.AwsCfgSrcType = blsgen.AwsCfgSrcPrompt |
||||
case "none": |
||||
config.AwsCfgSrcType = blsgen.AwsCfgSrcNil |
||||
} |
||||
return config, nil |
||||
} |
@ -0,0 +1,54 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"bufio" |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
"syscall" |
||||
|
||||
"golang.org/x/crypto/ssh/terminal" |
||||
) |
||||
|
||||
var console consoleItf = &stdConsole{} |
||||
|
||||
// consoleItf define the interface for module level console input and outputs
|
||||
type consoleItf interface { |
||||
readPassword() (string, error) |
||||
readln() (string, error) |
||||
print(a ...interface{}) |
||||
println(a ...interface{}) |
||||
printf(format string, a ...interface{}) |
||||
} |
||||
|
||||
type stdConsole struct{} |
||||
|
||||
func (console *stdConsole) readPassword() (string, error) { |
||||
b, err := terminal.ReadPassword(syscall.Stdin) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
console.println() |
||||
return strings.TrimSpace(string(b)), nil |
||||
} |
||||
|
||||
func (console *stdConsole) readln() (string, error) { |
||||
reader := bufio.NewReader(os.Stdin) |
||||
raw, err := reader.ReadString('\n') |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return strings.TrimSpace(raw), nil |
||||
} |
||||
|
||||
func (console *stdConsole) print(a ...interface{}) { |
||||
fmt.Print(a...) |
||||
} |
||||
|
||||
func (console *stdConsole) println(a ...interface{}) { |
||||
fmt.Println(a...) |
||||
} |
||||
|
||||
func (console *stdConsole) printf(format string, a ...interface{}) { |
||||
fmt.Printf(format, a...) |
||||
} |
@ -0,0 +1,76 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"time" |
||||
) |
||||
|
||||
func setTestConsole(tc *testConsole) { |
||||
console = tc |
||||
} |
||||
|
||||
type testConsole struct { |
||||
In chan string |
||||
Out chan string |
||||
} |
||||
|
||||
func newTestConsole() *testConsole { |
||||
in := make(chan string, 100) |
||||
out := make(chan string, 100) |
||||
return &testConsole{in, out} |
||||
} |
||||
|
||||
func (tc *testConsole) readPassword() (string, error) { |
||||
return tc.readln() |
||||
} |
||||
|
||||
func (tc *testConsole) readln() (string, error) { |
||||
select { |
||||
case <-time.After(2 * time.Second): |
||||
return "", errors.New("timed out") |
||||
case msg, ok := <-tc.In: |
||||
if !ok { |
||||
return "", errors.New("in channel closed") |
||||
} |
||||
return msg, nil |
||||
} |
||||
} |
||||
|
||||
func (tc *testConsole) print(a ...interface{}) { |
||||
msg := fmt.Sprint(a...) |
||||
tc.Out <- msg |
||||
} |
||||
|
||||
func (tc *testConsole) println(a ...interface{}) { |
||||
msg := fmt.Sprintln(a...) |
||||
tc.Out <- msg |
||||
} |
||||
|
||||
func (tc *testConsole) printf(format string, a ...interface{}) { |
||||
msg := fmt.Sprintf(format, a...) |
||||
tc.Out <- msg |
||||
} |
||||
|
||||
func (tc *testConsole) checkClean() (bool, string) { |
||||
select { |
||||
case msg := <-tc.In: |
||||
return false, "extra in message: " + msg |
||||
case msg := <-tc.Out: |
||||
return false, "extra out message: " + msg |
||||
default: |
||||
return true, "" |
||||
} |
||||
} |
||||
|
||||
func (tc *testConsole) checkOutput(timeout time.Duration, checkFunc func(string) error) error { |
||||
if timeout == 0 { |
||||
timeout = 10 * time.Second |
||||
} |
||||
select { |
||||
case <-time.After(timeout): |
||||
return errors.New("timed out") |
||||
case msg := <-tc.Out: |
||||
return checkFunc(msg) |
||||
} |
||||
} |
@ -0,0 +1,95 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
|
||||
bls_core "github.com/harmony-one/bls/ffi/go/bls" |
||||
"github.com/harmony-one/harmony/multibls" |
||||
) |
||||
|
||||
// loadHelper defines the helper interface to load bls keys. Implemented by
|
||||
// multiKeyLoader - load key files with a slice of target key files
|
||||
// blsDirLoader - load key files from a directory
|
||||
type loadHelper interface { |
||||
loadKeys() (multibls.PrivateKeys, error) |
||||
} |
||||
|
||||
// multiKeyLoader load keys from multiple bls key files
|
||||
type multiKeyLoader struct { |
||||
keyFiles []string |
||||
decrypters map[string]keyDecrypter |
||||
|
||||
loadedSecrets []*bls_core.SecretKey |
||||
} |
||||
|
||||
func newMultiKeyLoader(keyFiles []string, decrypters []keyDecrypter) (*multiKeyLoader, error) { |
||||
dm := make(map[string]keyDecrypter) |
||||
for _, decrypter := range decrypters { |
||||
dm[decrypter.extension()] = decrypter |
||||
} |
||||
for _, keyFile := range keyFiles { |
||||
ext := filepath.Ext(keyFile) |
||||
if _, supported := dm[ext]; !supported { |
||||
return nil, fmt.Errorf("unsupported key extension: %v", ext) |
||||
} |
||||
} |
||||
return &multiKeyLoader{ |
||||
keyFiles: keyFiles, |
||||
decrypters: dm, |
||||
loadedSecrets: make([]*bls_core.SecretKey, 0, len(keyFiles)), |
||||
}, nil |
||||
} |
||||
|
||||
func (loader *multiKeyLoader) loadKeys() (multibls.PrivateKeys, error) { |
||||
for _, keyFile := range loader.keyFiles { |
||||
decrypter := loader.decrypters[filepath.Ext(keyFile)] |
||||
secret, err := decrypter.decryptFile(keyFile) |
||||
if err != nil { |
||||
return multibls.PrivateKeys{}, err |
||||
} |
||||
loader.loadedSecrets = append(loader.loadedSecrets, secret) |
||||
} |
||||
return multibls.GetPrivateKeys(loader.loadedSecrets...), nil |
||||
} |
||||
|
||||
type blsDirLoader struct { |
||||
keyDir string |
||||
decrypters map[string]keyDecrypter |
||||
|
||||
loadedSecrets []*bls_core.SecretKey |
||||
} |
||||
|
||||
func newBlsDirLoader(keyDir string, decrypters []keyDecrypter) (*blsDirLoader, error) { |
||||
dm := make(map[string]keyDecrypter) |
||||
for _, decrypter := range decrypters { |
||||
dm[decrypter.extension()] = decrypter |
||||
} |
||||
if err := checkIsDir(keyDir); err != nil { |
||||
return nil, err |
||||
} |
||||
return &blsDirLoader{ |
||||
keyDir: keyDir, |
||||
decrypters: dm, |
||||
}, nil |
||||
} |
||||
|
||||
func (loader *blsDirLoader) loadKeys() (multibls.PrivateKeys, error) { |
||||
filepath.Walk(loader.keyDir, func(path string, info os.FileInfo, err error) error { |
||||
if err != nil { |
||||
return err |
||||
} |
||||
decrypter, exist := loader.decrypters[filepath.Ext(path)] |
||||
if !exist { |
||||
return nil |
||||
} |
||||
secret, err := decrypter.decryptFile(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
loader.loadedSecrets = append(loader.loadedSecrets, secret) |
||||
return nil |
||||
}) |
||||
return multibls.GetPrivateKeys(loader.loadedSecrets...), nil |
||||
} |
@ -0,0 +1,184 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"os" |
||||
"path/filepath" |
||||
"testing" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
bls_core "github.com/harmony-one/bls/ffi/go/bls" |
||||
) |
||||
|
||||
const ( |
||||
testExt = ".test1" |
||||
) |
||||
|
||||
func TestNewMultiKeyLoader(t *testing.T) { |
||||
tests := []struct { |
||||
keyFiles []string |
||||
decrypters []keyDecrypter |
||||
expErr error |
||||
}{ |
||||
{ |
||||
keyFiles: []string{ |
||||
"test/keyfile1.key", |
||||
"keyfile2.bls", |
||||
}, |
||||
decrypters: []keyDecrypter{ |
||||
&passDecrypter{}, |
||||
&kmsDecrypter{}, |
||||
}, |
||||
expErr: nil, |
||||
}, |
||||
{ |
||||
keyFiles: []string{ |
||||
"test/keyfile1.key", |
||||
"keyfile2.bls", |
||||
}, |
||||
decrypters: []keyDecrypter{ |
||||
&passDecrypter{}, |
||||
}, |
||||
expErr: errors.New("unsupported key extension"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
_, err := newMultiKeyLoader(test.keyFiles, test.decrypters) |
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestMultiKeyLoader_loadKeys(t *testing.T) { |
||||
setTestConsole(newTestConsole()) |
||||
|
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
os.Remove(unitTestDir) |
||||
os.MkdirAll(unitTestDir, 0700) |
||||
|
||||
keyFile1 := filepath.Join(unitTestDir, testKeys[0].publicKey+basicKeyExt) |
||||
if err := writeFile(keyFile1, testKeys[0].keyFileData); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
keyFile2 := filepath.Join(unitTestDir, testKeys[1].publicKey+testExt) |
||||
if err := writeFile(keyFile2, testKeys[1].keyFileData); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
passFile1 := filepath.Join(unitTestDir, testKeys[0].publicKey+passExt) |
||||
if err := writeFile(passFile1, testKeys[0].passphrase); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
decrypters := map[string]keyDecrypter{ |
||||
basicKeyExt: &passDecrypter{pps: []passProvider{&dynamicPassProvider{}}}, |
||||
testExt: newTestPassDecrypter(), |
||||
} |
||||
|
||||
loader := &multiKeyLoader{ |
||||
keyFiles: []string{keyFile1, keyFile2}, |
||||
decrypters: decrypters, |
||||
loadedSecrets: make([]*bls_core.SecretKey, 0, 2), |
||||
} |
||||
keys, err := loader.loadKeys() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if len(keys) != 2 { |
||||
t.Fatalf("unexpected number of keys: %v / 2", len(keys)) |
||||
} |
||||
gotPubs := [][]byte{ |
||||
keys[0].Pub.Bytes[:], |
||||
keys[1].Pub.Bytes[:], |
||||
} |
||||
expPubs := [][]byte{ |
||||
common.Hex2Bytes(testKeys[0].publicKey), |
||||
common.Hex2Bytes(testKeys[1].publicKey), |
||||
} |
||||
for i := range gotPubs { |
||||
got, exp := gotPubs[i], expPubs[i] |
||||
if !bytes.Equal(got, exp) { |
||||
t.Fatalf("%v pubkey unexpected: %x / %x", i, got, exp) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestBlsDirLoader(t *testing.T) { |
||||
setTestConsole(newTestConsole()) |
||||
|
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
os.Remove(unitTestDir) |
||||
os.MkdirAll(unitTestDir, 0700) |
||||
|
||||
keyFile1 := filepath.Join(unitTestDir, testKeys[0].publicKey+basicKeyExt) |
||||
if err := writeFile(keyFile1, testKeys[0].keyFileData); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
keyFile2 := filepath.Join(unitTestDir, testKeys[1].publicKey+testExt) |
||||
if err := writeFile(keyFile2, testKeys[1].keyFileData); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
passFile1 := filepath.Join(unitTestDir, testKeys[0].publicKey+passExt) |
||||
if err := writeFile(passFile1, testKeys[0].passphrase); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
// write a file without the given extension
|
||||
if err := writeFile(filepath.Join(unitTestDir, "unknown.ext"), "random string"); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
decrypters := []keyDecrypter{ |
||||
&passDecrypter{pps: []passProvider{&dynamicPassProvider{}}}, |
||||
newTestPassDecrypter(), |
||||
} |
||||
|
||||
loader, err := newBlsDirLoader(unitTestDir, decrypters) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
keys, err := loader.loadKeys() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
if len(keys) != 2 { |
||||
t.Fatalf("unexpected number of keys: %v / 2", len(keys)) |
||||
} |
||||
gotPubs := [][]byte{ |
||||
keys[0].Pub.Bytes[:], |
||||
keys[1].Pub.Bytes[:], |
||||
} |
||||
expPubs := [][]byte{ |
||||
common.Hex2Bytes(testKeys[0].publicKey), |
||||
common.Hex2Bytes(testKeys[1].publicKey), |
||||
} |
||||
for i := range gotPubs { |
||||
got, exp := gotPubs[i], expPubs[i] |
||||
if !bytes.Equal(got, exp) { |
||||
t.Fatalf("%v pubkey unexpected: %x / %x", i, got, exp) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type testPassDecrypter struct { |
||||
pd passDecrypter |
||||
} |
||||
|
||||
func newTestPassDecrypter() *testPassDecrypter { |
||||
provider := &testPassProvider{m: map[string]string{ |
||||
testKeys[0].publicKey: testKeys[0].passphrase, |
||||
testKeys[1].publicKey: testKeys[1].passphrase, |
||||
}} |
||||
return &testPassDecrypter{ |
||||
pd: passDecrypter{ |
||||
pps: []passProvider{provider}, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func (decrypter *testPassDecrypter) extension() string { |
||||
return testExt |
||||
} |
||||
|
||||
func (decrypter *testPassDecrypter) decryptFile(keyFile string) (*bls_core.SecretKey, error) { |
||||
return decrypter.pd.decryptFile(keyFile) |
||||
} |
@ -0,0 +1,279 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/aws/aws-sdk-go/aws" |
||||
"github.com/aws/aws-sdk-go/aws/credentials" |
||||
"github.com/aws/aws-sdk-go/aws/session" |
||||
"github.com/aws/aws-sdk-go/service/kms" |
||||
bls_core "github.com/harmony-one/bls/ffi/go/bls" |
||||
"github.com/pkg/errors" |
||||
) |
||||
|
||||
// AwsCfgSrcType is the type of src to load aws config. Four options available:
|
||||
// AwsCfgSrcNil - Disable kms decryption
|
||||
// AwsCfgSrcFile - Provide the aws config through a file (json).
|
||||
// AwsCfgSrcPrompt - Provide the aws config though prompt.
|
||||
// AwsCfgSrcShared - Use the shard aws config (env -> default .aws directory)
|
||||
type AwsCfgSrcType uint8 |
||||
|
||||
const ( |
||||
// AwsCfgSrcNil is the nil place holder for AwsCfgSrcType.
|
||||
AwsCfgSrcNil AwsCfgSrcType = iota |
||||
// AwsCfgSrcFile instruct reading aws config through a json file.
|
||||
AwsCfgSrcFile |
||||
// AwsCfgSrcPrompt use a user interactive prompt to ge aws config.
|
||||
AwsCfgSrcPrompt |
||||
// AwsCfgSrcShared use shared AWS config and credentials from env and ~/.aws files.
|
||||
AwsCfgSrcShared |
||||
) |
||||
|
||||
func (srcType AwsCfgSrcType) isValid() bool { |
||||
switch srcType { |
||||
case AwsCfgSrcFile, AwsCfgSrcPrompt, AwsCfgSrcShared: |
||||
return true |
||||
default: |
||||
return false |
||||
} |
||||
} |
||||
|
||||
// kmsDecrypterConfig is the data structure of kmsClientProvider config
|
||||
type kmsDecrypterConfig struct { |
||||
awsCfgSrcType AwsCfgSrcType |
||||
awsConfigFile *string |
||||
} |
||||
|
||||
// kmsDecrypter provide the kms client with singleton lazy initialization with config get
|
||||
// from awsConfigProvider for aws credential and regions loading.
|
||||
type kmsDecrypter struct { |
||||
config kmsDecrypterConfig |
||||
|
||||
provider awsConfigProvider |
||||
client *kms.KMS |
||||
err error |
||||
once sync.Once |
||||
} |
||||
|
||||
// newKmsDecrypter creates a kmsDecrypter with the given config
|
||||
func newKmsDecrypter(config kmsDecrypterConfig) (*kmsDecrypter, error) { |
||||
kd := &kmsDecrypter{config: config} |
||||
if err := kd.validateConfig(); err != nil { |
||||
return nil, err |
||||
} |
||||
kd.makeACProvider() |
||||
return kd, nil |
||||
} |
||||
|
||||
// extension returns the kms key file extension
|
||||
func (kd *kmsDecrypter) extension() string { |
||||
return kmsKeyExt |
||||
} |
||||
|
||||
// decryptFile decrypt a kms key file to a secret key
|
||||
func (kd *kmsDecrypter) decryptFile(keyFile string) (*bls_core.SecretKey, error) { |
||||
kms, err := kd.getKMSClient() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return LoadAwsCMKEncryptedBLSKey(keyFile, kms) |
||||
} |
||||
|
||||
func (kd *kmsDecrypter) validateConfig() error { |
||||
config := kd.config |
||||
if !config.awsCfgSrcType.isValid() { |
||||
return errors.New("unknown AwsCfgSrcType") |
||||
} |
||||
if config.awsCfgSrcType == AwsCfgSrcFile { |
||||
if !stringIsSet(config.awsConfigFile) { |
||||
return errors.New("config field AwsConfig file must set for AwsCfgSrcFile") |
||||
} |
||||
if err := checkIsFile(*config.awsConfigFile); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (kd *kmsDecrypter) makeACProvider() { |
||||
config := kd.config |
||||
switch config.awsCfgSrcType { |
||||
case AwsCfgSrcFile: |
||||
kd.provider = newFileACProvider(*config.awsConfigFile) |
||||
case AwsCfgSrcPrompt: |
||||
kd.provider = newPromptACProvider(defKmsPromptTimeout) |
||||
case AwsCfgSrcShared: |
||||
kd.provider = newSharedAwsConfigProvider() |
||||
} |
||||
} |
||||
|
||||
func (kd *kmsDecrypter) getKMSClient() (*kms.KMS, error) { |
||||
kd.once.Do(func() { |
||||
cfg, err := kd.provider.getAwsConfig() |
||||
if err != nil { |
||||
kd.err = err |
||||
return |
||||
} |
||||
kd.client, kd.err = kmsClientWithConfig(cfg) |
||||
}) |
||||
if kd.err != nil { |
||||
return nil, kd.err |
||||
} |
||||
return kd.client, nil |
||||
} |
||||
|
||||
// AwsConfig is the config data structure for credentials and region. Used for AWS KMS
|
||||
// decryption.
|
||||
type AwsConfig struct { |
||||
AccessKey string `json:"aws-access-key-id"` |
||||
SecretKey string `json:"aws-secret-access-key"` |
||||
Region string `json:"aws-region"` |
||||
Token string `json:"aws-token,omitempty"` |
||||
} |
||||
|
||||
func (cfg AwsConfig) toAws() *aws.Config { |
||||
cred := credentials.NewStaticCredentials(cfg.AccessKey, cfg.SecretKey, cfg.Token) |
||||
return &aws.Config{ |
||||
Region: aws.String(cfg.Region), |
||||
Credentials: cred, |
||||
} |
||||
} |
||||
|
||||
// awsConfigProvider provides the aws config. Implemented by
|
||||
// sharedACProvider - provide the nil to use shared AWS configuration
|
||||
// fileACProvider - provide the aws config with a json file
|
||||
// promptACProvider - provide the config field from prompt with time out
|
||||
// TODO: load aws session set up in a more official way. E.g. session.Opt.SharedConfigFiles,
|
||||
// profile, env, e.t.c.
|
||||
type awsConfigProvider interface { |
||||
getAwsConfig() (*AwsConfig, error) |
||||
} |
||||
|
||||
// sharedACProvider returns nil for getAwsConfig to use shared aws configurations
|
||||
type sharedACProvider struct{} |
||||
|
||||
func newSharedAwsConfigProvider() *sharedACProvider { |
||||
return &sharedACProvider{} |
||||
} |
||||
|
||||
func (provider *sharedACProvider) getAwsConfig() (*AwsConfig, error) { |
||||
return nil, nil |
||||
} |
||||
|
||||
// fileACProvider get aws config through a customized json file
|
||||
type fileACProvider struct { |
||||
file string |
||||
} |
||||
|
||||
func newFileACProvider(file string) *fileACProvider { |
||||
return &fileACProvider{file} |
||||
} |
||||
|
||||
func (provider *fileACProvider) getAwsConfig() (*AwsConfig, error) { |
||||
b, err := ioutil.ReadFile(provider.file) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
var cfg AwsConfig |
||||
if err := json.Unmarshal(b, &cfg); err != nil { |
||||
return nil, err |
||||
} |
||||
return &cfg, nil |
||||
} |
||||
|
||||
// promptACProvider provide a user interactive console for AWS config.
|
||||
// Four fields are asked:
|
||||
// 1. AccessKey 2. SecretKey 3. Region
|
||||
// Each field is asked with a timeout mechanism.
|
||||
type promptACProvider struct { |
||||
timeout time.Duration |
||||
} |
||||
|
||||
func newPromptACProvider(timeout time.Duration) *promptACProvider { |
||||
return &promptACProvider{ |
||||
timeout: timeout, |
||||
} |
||||
} |
||||
|
||||
func (provider *promptACProvider) getAwsConfig() (*AwsConfig, error) { |
||||
console.println("Please provide AWS configurations for KMS encoded BLS keys:") |
||||
accessKey, err := provider.prompt(" AccessKey:") |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot get aws access key: %v", err) |
||||
} |
||||
secretKey, err := provider.prompt(" SecretKey:") |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot get aws secret key: %v", err) |
||||
} |
||||
region, err := provider.prompt(" Region:") |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot get aws region: %v", err) |
||||
} |
||||
return &AwsConfig{ |
||||
AccessKey: accessKey, |
||||
SecretKey: secretKey, |
||||
Region: region, |
||||
Token: "", |
||||
}, nil |
||||
} |
||||
|
||||
// prompt prompt the user to input a string for a certain field with timeout.
|
||||
func (provider *promptACProvider) prompt(hint string) (string, error) { |
||||
var ( |
||||
res string |
||||
err error |
||||
|
||||
finished = make(chan struct{}) |
||||
timedOut = time.After(provider.timeout) |
||||
) |
||||
|
||||
cs := console |
||||
go func() { |
||||
res, err = provider.threadedPrompt(cs, hint) |
||||
close(finished) |
||||
}() |
||||
|
||||
for { |
||||
select { |
||||
case <-finished: |
||||
return res, err |
||||
case <-timedOut: |
||||
console.println("ERROR input time out") |
||||
return "", errors.New("timed out") |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (provider *promptACProvider) threadedPrompt(cs consoleItf, hint string) (string, error) { |
||||
cs.print(hint) |
||||
return cs.readPassword() |
||||
} |
||||
|
||||
func kmsClientWithConfig(config *AwsConfig) (*kms.KMS, error) { |
||||
if config == nil { |
||||
return getSharedKMSClient() |
||||
} |
||||
return getKMSClientFromConfig(*config) |
||||
} |
||||
|
||||
func getSharedKMSClient() (*kms.KMS, error) { |
||||
sess, err := session.NewSessionWithOptions(session.Options{ |
||||
SharedConfigState: session.SharedConfigEnable, |
||||
}) |
||||
if err != nil { |
||||
return nil, errors.Wrapf(err, "failed to create aws session") |
||||
} |
||||
return kms.New(sess), err |
||||
} |
||||
|
||||
func getKMSClientFromConfig(config AwsConfig) (*kms.KMS, error) { |
||||
sess, err := session.NewSession(config.toAws()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return kms.New(sess), nil |
||||
} |
@ -0,0 +1,271 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"encoding/hex" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/aws/aws-sdk-go/service/kms" |
||||
"github.com/ethereum/go-ethereum/common" |
||||
ffi_bls "github.com/harmony-one/bls/ffi/go/bls" |
||||
"github.com/harmony-one/harmony/crypto/bls" |
||||
) |
||||
|
||||
var TestAwsConfig = AwsConfig{ |
||||
AccessKey: "access key", |
||||
SecretKey: "secret key", |
||||
Region: "region", |
||||
} |
||||
|
||||
func TestNewKmsDecrypter(t *testing.T) { |
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
testFile := filepath.Join(unitTestDir, "test.json") |
||||
if err := writeAwsConfigFile(testFile, TestAwsConfig); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
emptyFile := filepath.Join(unitTestDir, "empty.json") |
||||
|
||||
tests := []struct { |
||||
config kmsDecrypterConfig |
||||
expProvider awsConfigProvider |
||||
expErr error |
||||
}{ |
||||
{ |
||||
config: kmsDecrypterConfig{ |
||||
awsCfgSrcType: AwsCfgSrcNil, |
||||
}, |
||||
expErr: errors.New("unknown AwsCfgSrcType"), |
||||
}, |
||||
{ |
||||
config: kmsDecrypterConfig{ |
||||
awsCfgSrcType: AwsCfgSrcShared, |
||||
}, |
||||
expProvider: &sharedACProvider{}, |
||||
}, |
||||
{ |
||||
config: kmsDecrypterConfig{ |
||||
awsCfgSrcType: AwsCfgSrcPrompt, |
||||
}, |
||||
expProvider: &promptACProvider{}, |
||||
}, |
||||
{ |
||||
config: kmsDecrypterConfig{ |
||||
awsCfgSrcType: AwsCfgSrcFile, |
||||
awsConfigFile: &testFile, |
||||
}, |
||||
expProvider: &fileACProvider{}, |
||||
}, |
||||
{ |
||||
config: kmsDecrypterConfig{ |
||||
awsCfgSrcType: AwsCfgSrcFile, |
||||
}, |
||||
expErr: errors.New("config field AwsConfig file must set for AwsCfgSrcFile"), |
||||
}, |
||||
{ |
||||
config: kmsDecrypterConfig{ |
||||
awsCfgSrcType: AwsCfgSrcFile, |
||||
awsConfigFile: &emptyFile, |
||||
}, |
||||
expErr: errors.New("no such file"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
kd, err := newKmsDecrypter(test.config) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if err != nil || test.expErr != nil { |
||||
continue |
||||
} |
||||
gotType := reflect.TypeOf(kd.provider).Elem() |
||||
expType := reflect.TypeOf(test.expProvider).Elem() |
||||
if gotType != expType { |
||||
t.Errorf("Test %v: unexpected aws config provider type: %v / %v", |
||||
i, gotType, expType) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func writeAwsConfigFile(file string, config AwsConfig) error { |
||||
b, err := json.Marshal(config) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, err := os.Stat(filepath.Dir(file)); err != nil { |
||||
if os.IsNotExist(err) { |
||||
os.MkdirAll(filepath.Dir(file), 0700) |
||||
} else { |
||||
return err |
||||
} |
||||
} |
||||
return ioutil.WriteFile(file, b, 0700) |
||||
} |
||||
|
||||
func TestPromptACProvider_getAwsConfig(t *testing.T) { |
||||
tc := newTestConsole() |
||||
setTestConsole(tc) |
||||
|
||||
for _, input := range []string{ |
||||
TestAwsConfig.AccessKey, |
||||
TestAwsConfig.SecretKey, |
||||
TestAwsConfig.Region, |
||||
} { |
||||
tc.In <- input |
||||
} |
||||
provider := newPromptACProvider(1 * time.Second) |
||||
got, err := provider.getAwsConfig() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if !reflect.DeepEqual(*got, TestAwsConfig) { |
||||
t.Errorf("unexpected result %+v / %+v", got, TestAwsConfig) |
||||
} |
||||
} |
||||
|
||||
func TestPromptACProvider_prompt(t *testing.T) { |
||||
tests := []struct { |
||||
delay, timeout time.Duration |
||||
expErr error |
||||
}{ |
||||
{ |
||||
delay: 100 * time.Microsecond, |
||||
timeout: 1000 * time.Microsecond, |
||||
expErr: nil, |
||||
}, |
||||
{ |
||||
delay: 2000 * time.Microsecond, |
||||
timeout: 1000 * time.Microsecond, |
||||
expErr: errors.New("timed out"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
tc := newTestConsole() |
||||
setTestConsole(tc) |
||||
|
||||
testInput := "test" |
||||
go func() { |
||||
<-time.After(test.delay) |
||||
tc.In <- testInput |
||||
}() |
||||
provider := newPromptACProvider(test.timeout) |
||||
got, err := provider.prompt("test ask string") |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if err != nil { |
||||
continue |
||||
} |
||||
if got != testInput { |
||||
t.Errorf("Test %v: unexpected prompt result: %v / %v", i, got, testInput) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestFileACProvider_getAwsConfig(t *testing.T) { |
||||
jsonBytes, err := json.Marshal(TestAwsConfig) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
|
||||
tests := []struct { |
||||
setupFunc func(rootDir string) error |
||||
expConfig AwsConfig |
||||
jsonFile string |
||||
expErr error |
||||
}{ |
||||
{ |
||||
// positive
|
||||
setupFunc: func(rootDir string) error { |
||||
jsonFile := filepath.Join(rootDir, "valid.json") |
||||
return writeFile(jsonFile, string(jsonBytes)) |
||||
}, |
||||
jsonFile: "valid.json", |
||||
expConfig: TestAwsConfig, |
||||
}, |
||||
{ |
||||
// no such file
|
||||
setupFunc: nil, |
||||
jsonFile: "empty.json", |
||||
expErr: errors.New("no such file"), |
||||
}, |
||||
{ |
||||
// invalid json string
|
||||
setupFunc: func(rootDir string) error { |
||||
jsonFile := filepath.Join(rootDir, "invalid.json") |
||||
return writeFile(jsonFile, string(jsonBytes[:len(jsonBytes)-2])) |
||||
}, |
||||
jsonFile: "invalid.json", |
||||
expErr: errors.New("unexpected end of JSON input"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i)) |
||||
os.RemoveAll(tcDir) |
||||
os.MkdirAll(tcDir, 0700) |
||||
if test.setupFunc != nil { |
||||
if err := test.setupFunc(tcDir); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
|
||||
provider := newFileACProvider(filepath.Join(tcDir, test.jsonFile)) |
||||
got, err := provider.getAwsConfig() |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
if err != nil || test.expErr != nil { |
||||
continue |
||||
} |
||||
if got == nil || !reflect.DeepEqual(*got, test.expConfig) { |
||||
t.Errorf("Test %v: unexpected AwsConfig: %+v / %+v", i, |
||||
got, test.expConfig) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// This is the learning test for kms encryption and decryption. This is just to illustrate
|
||||
// To successfully run this test, need to set the AWS default configuration and set up kms
|
||||
// key and replace keyId field.
|
||||
func TestKMSEncryption(t *testing.T) { |
||||
t.SkipNow() |
||||
client, err := getSharedKMSClient() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
privHex := testKeys[0].privateKey |
||||
keyID := "26adbb7b-6c46-4763-a7b3-de7ee768890a" // Replace your key ID here
|
||||
|
||||
output, err := client.Encrypt(&kms.EncryptInput{ |
||||
KeyId: &keyID, |
||||
Plaintext: common.Hex2Bytes(privHex), |
||||
}) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
fmt.Printf("encrypted: [%x]\n", output.CiphertextBlob) |
||||
decryted, err := client.Decrypt(&kms.DecryptInput{CiphertextBlob: output.CiphertextBlob}) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
priKey := &ffi_bls.SecretKey{} |
||||
if err = priKey.DeserializeHexStr(hex.EncodeToString(decryted.Plaintext)); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
pubKey := bls.FromLibBLSPublicKeyUnsafe(priKey.GetPublicKey()) |
||||
if hex.EncodeToString(pubKey[:]) != testKeys[0].publicKey { |
||||
t.Errorf("unexpected public key") |
||||
} |
||||
} |
@ -0,0 +1,119 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
|
||||
bls_core "github.com/harmony-one/bls/ffi/go/bls" |
||||
"github.com/harmony-one/harmony/multibls" |
||||
) |
||||
|
||||
// LoadKeys load all BLS keys with the given config. If loading keys from files, the
|
||||
// file extension will decide which decryption algorithm to use.
|
||||
func LoadKeys(cfg Config) (multibls.PrivateKeys, error) { |
||||
decrypters, err := getKeyDecrypters(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
helper, err := getHelper(cfg, decrypters) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return helper.loadKeys() |
||||
} |
||||
|
||||
// Config is the config structure for LoadKeys.
|
||||
type Config struct { |
||||
// source for bls key loading. At least one of the MultiBlsKeys and BlsDir
|
||||
// need to be provided.
|
||||
//
|
||||
// MultiBlsKeys defines a slice of key files to load from.
|
||||
MultiBlsKeys []string |
||||
// BlsDir defines a file directory to load keys from.
|
||||
BlsDir *string |
||||
|
||||
// Passphrase related settings. Used for passphrase encrypted key files.
|
||||
//
|
||||
// PassSrcType defines the source to get passphrase. Three source types are available
|
||||
// PassSrcNil - do not use passphrase decryption
|
||||
// PassSrcFile - get passphrase from a .pass file
|
||||
// PassSrcPrompt - get passphrase from prompt
|
||||
// PassSrcAuto - try to unlock with .pass file. If not success, ask user with prompt
|
||||
PassSrcType PassSrcType |
||||
// PassFile specifies the .pass file to be used when loading passphrase from file.
|
||||
// If not set, default to the .pass file in the same directory as the key file.
|
||||
PassFile *string |
||||
// PersistPassphrase set whether to persist the passphrase to a .pass file when
|
||||
// prompt the user for passphrase. Persisted pass file is a file with .pass extension
|
||||
// under the same directory as the key file.
|
||||
PersistPassphrase bool |
||||
|
||||
// KMS related settings, including AWS credentials and region info.
|
||||
// Used for KMS encrypted passphrase files.
|
||||
//
|
||||
// AwsCfgSrcType defines the source to get aws config. Three types available:
|
||||
// AwsCfgSrcNil - do not use Aws KMS decryption service.
|
||||
// AwsCfgSrcFile - get AWS config through a json file. See AwsConfig for content fields.
|
||||
// AwsCfgSrcPrompt - get AWS config through prompt.
|
||||
// AwsCfgSrcShared - Use the default AWS config settings (from env and $HOME/.aws/config)
|
||||
AwsCfgSrcType AwsCfgSrcType |
||||
// AwsConfigFile set the json file to load aws config.
|
||||
AwsConfigFile *string |
||||
} |
||||
|
||||
func (cfg *Config) getPassProviderConfig() passDecrypterConfig { |
||||
return passDecrypterConfig{ |
||||
passSrcType: cfg.PassSrcType, |
||||
passFile: cfg.PassFile, |
||||
persistPassphrase: cfg.PersistPassphrase, |
||||
} |
||||
} |
||||
|
||||
func (cfg *Config) getKmsProviderConfig() kmsDecrypterConfig { |
||||
return kmsDecrypterConfig{ |
||||
awsCfgSrcType: cfg.AwsCfgSrcType, |
||||
awsConfigFile: cfg.AwsConfigFile, |
||||
} |
||||
} |
||||
|
||||
// keyDecrypter is the interface to decrypt the bls key file. Currently, two
|
||||
// implementations are supported:
|
||||
// passDecrypter - decrypt with passphrase for file name with extension .key
|
||||
// kmsDecrypter - decrypt with aws kms service for file name with extension .bls
|
||||
type keyDecrypter interface { |
||||
extension() string |
||||
decryptFile(keyFile string) (*bls_core.SecretKey, error) |
||||
} |
||||
|
||||
func getKeyDecrypters(cfg Config) ([]keyDecrypter, error) { |
||||
var decrypters []keyDecrypter |
||||
if cfg.PassSrcType != PassSrcNil { |
||||
pd, err := newPassDecrypter(cfg.getPassProviderConfig()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
decrypters = append(decrypters, pd) |
||||
} |
||||
if cfg.AwsCfgSrcType != AwsCfgSrcNil { |
||||
kd, err := newKmsDecrypter(cfg.getKmsProviderConfig()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
decrypters = append(decrypters, kd) |
||||
} |
||||
if len(decrypters) == 0 { |
||||
return nil, fmt.Errorf("must provide at least one bls key decryption") |
||||
} |
||||
return decrypters, nil |
||||
} |
||||
|
||||
func getHelper(cfg Config, decrypters []keyDecrypter) (loadHelper, error) { |
||||
switch { |
||||
case len(cfg.MultiBlsKeys) != 0: |
||||
return newMultiKeyLoader(cfg.MultiBlsKeys, decrypters) |
||||
case stringIsSet(cfg.BlsDir): |
||||
return newBlsDirLoader(*cfg.BlsDir, decrypters) |
||||
default: |
||||
return nil, errors.New("either MultiBlsKeys or BlsDir must be set") |
||||
} |
||||
} |
@ -0,0 +1,111 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"reflect" |
||||
"testing" |
||||
) |
||||
|
||||
func ExampleLoadKeys() { |
||||
dir, err := prepareDataForExample() |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return |
||||
} |
||||
config := Config{ |
||||
BlsDir: &dir, |
||||
PassSrcType: PassSrcFile, // not assign PassFile to dynamically use .pass path
|
||||
AwsCfgSrcType: AwsCfgSrcNil, // disable loading file with kms
|
||||
} |
||||
|
||||
keys, err := LoadKeys(config) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return |
||||
} |
||||
|
||||
fmt.Printf("loaded %v keys\n", len(keys)) |
||||
for i, key := range keys { |
||||
fmt.Printf(" key %v: %x\n", i, key.Pub.Bytes) |
||||
} |
||||
// Output:
|
||||
//
|
||||
// loaded 2 keys
|
||||
// key 0: 0e969f8b302cf7648bc39652ca7a279a8562b72933a3f7cddac2252583280c7c3495c9ae854f00f6dd19c32fc5a17500
|
||||
// key 1: 152beed46d7a0002ef0f960946008887eedd4775bdf2ed238809aa74e20d31fdca267443615cc6f4ede49d58911ee083
|
||||
} |
||||
|
||||
func prepareDataForExample() (string, error) { |
||||
unitTestDir := filepath.Join(baseTestDir, "ExampleLoadKeys") |
||||
os.Remove(unitTestDir) |
||||
os.MkdirAll(unitTestDir, 0700) |
||||
|
||||
if err := writeKeyAndPass(unitTestDir, testKeys[0]); err != nil { |
||||
return "", err |
||||
} |
||||
if err := writeKeyAndPass(unitTestDir, testKeys[1]); err != nil { |
||||
return "", err |
||||
} |
||||
return unitTestDir, nil |
||||
} |
||||
|
||||
func writeKeyAndPass(dir string, key testKey) error { |
||||
keyFile := filepath.Join(dir, key.publicKey+basicKeyExt) |
||||
if err := writeFile(keyFile, key.keyFileData); err != nil { |
||||
return fmt.Errorf("cannot write key file data: %v", err) |
||||
} |
||||
passFile := filepath.Join(dir, key.publicKey+passExt) |
||||
if err := writeFile(passFile, key.passphrase); err != nil { |
||||
return fmt.Errorf("cannot write pass file data: %v", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func TestGetKeyDecrypters(t *testing.T) { |
||||
tests := []struct { |
||||
config Config |
||||
expTypes []keyDecrypter |
||||
expErr error |
||||
}{ |
||||
{ |
||||
config: Config{ |
||||
PassSrcType: PassSrcNil, |
||||
AwsCfgSrcType: AwsCfgSrcNil, |
||||
}, |
||||
expErr: errors.New("must provide at least one bls key decryption"), |
||||
}, |
||||
{ |
||||
config: Config{ |
||||
PassSrcType: PassSrcFile, |
||||
AwsCfgSrcType: AwsCfgSrcShared, |
||||
}, |
||||
expTypes: []keyDecrypter{ |
||||
&passDecrypter{}, |
||||
&kmsDecrypter{}, |
||||
}, |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
decrypters, err := getKeyDecrypters(test.config) |
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
if err != nil || test.expErr != nil { |
||||
continue |
||||
} |
||||
if len(decrypters) != len(test.expTypes) { |
||||
t.Errorf("Test %v: unexpected decrypter size: %v / %v", i, len(decrypters), len(test.expTypes)) |
||||
continue |
||||
} |
||||
for ti := range decrypters { |
||||
gotType := reflect.TypeOf(decrypters[ti]).Elem() |
||||
expType := reflect.TypeOf(test.expTypes[ti]).Elem() |
||||
if gotType != expType { |
||||
t.Errorf("Test %v: %v decrypter type unexpected: %v / %v", i, ti, gotType, expType) |
||||
} |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,20 @@ |
||||
package blsgen |
||||
|
||||
import "time" |
||||
|
||||
const ( |
||||
// Extensions for files.
|
||||
passExt = ".pass" |
||||
basicKeyExt = ".key" |
||||
kmsKeyExt = ".bls" |
||||
) |
||||
|
||||
const ( |
||||
// The default timeout for kms config prompt. The timeout is introduced
|
||||
// for security concern.
|
||||
defKmsPromptTimeout = 60 * time.Second |
||||
) |
||||
|
||||
const ( |
||||
defWritePassFileMode = 0600 |
||||
) |
@ -0,0 +1,228 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"strings" |
||||
"sync" |
||||
|
||||
bls_core "github.com/harmony-one/bls/ffi/go/bls" |
||||
) |
||||
|
||||
// PassSrcType is the type of passphrase provider source.
|
||||
// Four options available:
|
||||
// PassSrcNil - Do not use passphrase decryption
|
||||
// PassSrcFile - Read the passphrase from files
|
||||
// PassSrcPrompt - Read the passphrase from prompt
|
||||
// PassSrcAuto - First try to unlock with passphrase from file, then read passphrase from prompt
|
||||
type PassSrcType uint8 |
||||
|
||||
const ( |
||||
// PassSrcNil is place holder for nil src
|
||||
PassSrcNil PassSrcType = iota |
||||
// PassSrcFile provide the passphrase through pass files
|
||||
PassSrcFile |
||||
// PassSrcPrompt provide the passphrase through prompt
|
||||
PassSrcPrompt |
||||
// PassSrcAuto first try to unlock with pass from file, then look for prompt
|
||||
PassSrcAuto |
||||
) |
||||
|
||||
func (srcType PassSrcType) isValid() bool { |
||||
switch srcType { |
||||
case PassSrcAuto, PassSrcFile, PassSrcPrompt: |
||||
return true |
||||
default: |
||||
return false |
||||
} |
||||
} |
||||
|
||||
// passDecrypterConfig is the data structure of passProviders config
|
||||
type passDecrypterConfig struct { |
||||
passSrcType PassSrcType |
||||
passFile *string |
||||
persistPassphrase bool |
||||
} |
||||
|
||||
// passDecrypter decrypt the .key bls files with passphrase from a series
|
||||
// of passProvider as passphrase source
|
||||
type passDecrypter struct { |
||||
config passDecrypterConfig |
||||
|
||||
pps []passProvider |
||||
} |
||||
|
||||
func newPassDecrypter(cfg passDecrypterConfig) (*passDecrypter, error) { |
||||
pd := &passDecrypter{config: cfg} |
||||
if err := pd.validateConfig(); err != nil { |
||||
return nil, err |
||||
} |
||||
pd.makePassProviders() |
||||
return pd, nil |
||||
} |
||||
|
||||
func (pd *passDecrypter) extension() string { |
||||
return basicKeyExt |
||||
} |
||||
|
||||
func (pd *passDecrypter) decryptFile(keyFile string) (*bls_core.SecretKey, error) { |
||||
for _, pp := range pd.pps { |
||||
secretKey, err := loadBasicKeyWithProvider(keyFile, pp) |
||||
if err != nil { |
||||
console.println(err) |
||||
continue |
||||
} |
||||
return secretKey, nil |
||||
} |
||||
return nil, fmt.Errorf("failed to load bls key %v", keyFile) |
||||
} |
||||
|
||||
func (pd *passDecrypter) validateConfig() error { |
||||
config := pd.config |
||||
if !config.passSrcType.isValid() { |
||||
return errors.New("unknown PassSrcType") |
||||
} |
||||
if stringIsSet(config.passFile) { |
||||
if err := checkIsFile(*config.passFile); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (pd *passDecrypter) makePassProviders() { |
||||
switch pd.config.passSrcType { |
||||
case PassSrcFile: |
||||
pd.pps = []passProvider{pd.getFilePassProvider()} |
||||
case PassSrcPrompt: |
||||
pd.pps = []passProvider{pd.getPromptPassProvider()} |
||||
case PassSrcAuto: |
||||
pd.pps = []passProvider{ |
||||
pd.getFilePassProvider(), |
||||
pd.getPromptPassProvider(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (pd *passDecrypter) getPromptPassProvider() passProvider { |
||||
return newPromptPassProvider(pd.config.persistPassphrase) |
||||
} |
||||
|
||||
func (pd *passDecrypter) getFilePassProvider() passProvider { |
||||
switch { |
||||
case stringIsSet(pd.config.passFile): |
||||
return newStaticPassProvider(*pd.config.passFile) |
||||
default: |
||||
return newDynamicPassProvider() |
||||
} |
||||
} |
||||
|
||||
// passProvider is the interface to provide the passphrase of a bls keys.
|
||||
// Implemented by
|
||||
// promptPassProvider - provide passphrase through user-interactive prompt
|
||||
// staticPassProvider - provide passphrase from a static .pass file
|
||||
// dynamicPassProvider - provide the passphrase based on the given key file keyFile
|
||||
// dirPassProvider - provide passphrase from .pass files in a directory
|
||||
type passProvider interface { |
||||
getPassphrase(keyFile string) (string, error) |
||||
} |
||||
|
||||
// promptPassProvider provides the bls passphrase through console prompt.
|
||||
type promptPassProvider struct { |
||||
// if enablePersist is true, after user enter the passphrase, the
|
||||
// passphrase is also persisted into .pass file under the same directory
|
||||
// of the key file
|
||||
enablePersist bool |
||||
} |
||||
|
||||
const pwdPromptStr = "Enter passphrase for the BLS key file %s:" |
||||
|
||||
func newPromptPassProvider(enablePersist bool) *promptPassProvider { |
||||
return &promptPassProvider{enablePersist: enablePersist} |
||||
} |
||||
|
||||
func (provider *promptPassProvider) getPassphrase(keyFile string) (string, error) { |
||||
prompt := fmt.Sprintf(pwdPromptStr, keyFile) |
||||
pass, err := promptGetPassword(prompt) |
||||
if err != nil { |
||||
return "", fmt.Errorf("unable to read from prompt: %v", err) |
||||
} |
||||
pass = strings.TrimSpace(pass) |
||||
// If user set to persist the pass file, persist to .pass file
|
||||
if provider.enablePersist { |
||||
if err := provider.persistPassphrase(keyFile, pass); err != nil { |
||||
return "", fmt.Errorf("unable to save passphrase: %v", err) |
||||
} |
||||
} |
||||
return pass, nil |
||||
} |
||||
|
||||
func (provider *promptPassProvider) persistPassphrase(keyFile string, passPhrase string) error { |
||||
passFile := keyFileToPassFileFull(keyFile) |
||||
if _, err := os.Stat(passFile); err == nil { |
||||
// File exist. Prompt user to overwrite pass file
|
||||
overwrite, err := promptYesNo(fmt.Sprintf("pass file [%v] already exist. Overwrite? ", passFile)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !overwrite { |
||||
return nil |
||||
} |
||||
} else if !os.IsNotExist(err) { |
||||
// Unknown error. Directly return
|
||||
return err |
||||
} |
||||
|
||||
return ioutil.WriteFile(passFile, []byte(passPhrase), defWritePassFileMode) |
||||
} |
||||
|
||||
// staticPassProvider provide the bls passphrase from a static .pass file
|
||||
type staticPassProvider struct { |
||||
fileName string |
||||
|
||||
// cached field
|
||||
pass string |
||||
err error |
||||
once sync.Once |
||||
} |
||||
|
||||
func newStaticPassProvider(fileName string) *staticPassProvider { |
||||
return &staticPassProvider{fileName: fileName} |
||||
} |
||||
|
||||
func (provider *staticPassProvider) getPassphrase(keyFile string) (string, error) { |
||||
provider.once.Do(func() { |
||||
provider.pass, provider.err = readPassFromFile(provider.fileName) |
||||
}) |
||||
return provider.pass, provider.err |
||||
} |
||||
|
||||
// dynamicPassProvider provide the passphrase based on .pass file with the given
|
||||
// key file keyFile. For example, looking for key file xxx.key will provide the
|
||||
// passphrase from xxx.pass
|
||||
type dynamicPassProvider struct{} |
||||
|
||||
func newDynamicPassProvider() passProvider { |
||||
return &dynamicPassProvider{} |
||||
} |
||||
|
||||
func (provider *dynamicPassProvider) getPassphrase(keyFile string) (string, error) { |
||||
passFile := keyFileToPassFileFull(keyFile) |
||||
return readPassFromFile(passFile) |
||||
} |
||||
|
||||
func readPassFromFile(file string) (string, error) { |
||||
f, err := os.Open(file) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
defer f.Close() |
||||
|
||||
b, err := ioutil.ReadAll(f) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return strings.TrimSpace(string(b)), nil |
||||
} |
@ -0,0 +1,381 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"reflect" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/ethereum/go-ethereum/common" |
||||
"github.com/harmony-one/harmony/crypto/bls" |
||||
) |
||||
|
||||
func TestNewPassDecrypter(t *testing.T) { |
||||
// setup
|
||||
var ( |
||||
testDir = filepath.Join(baseTestDir, t.Name()) |
||||
existPassFile = filepath.Join(testDir, testKeys[0].publicKey+passExt) |
||||
emptyPassFile = filepath.Join(testDir, testKeys[1].publicKey+passExt) |
||||
) |
||||
if err := writeFile(existPassFile, testKeys[0].passphrase); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
tests := []struct { |
||||
config passDecrypterConfig |
||||
expErr error |
||||
providerTypes []passProvider |
||||
}{ |
||||
{ |
||||
config: passDecrypterConfig{passSrcType: PassSrcNil}, |
||||
expErr: errors.New("unknown PassSrcType"), |
||||
}, |
||||
{ |
||||
config: passDecrypterConfig{ |
||||
passSrcType: PassSrcFile, |
||||
passFile: &emptyPassFile, |
||||
}, |
||||
expErr: errors.New("no such file or directory"), |
||||
}, |
||||
{ |
||||
config: passDecrypterConfig{ |
||||
passSrcType: PassSrcFile, |
||||
passFile: &existPassFile, |
||||
}, |
||||
expErr: nil, |
||||
providerTypes: []passProvider{ |
||||
&staticPassProvider{}, |
||||
}, |
||||
}, |
||||
{ |
||||
config: passDecrypterConfig{passSrcType: PassSrcPrompt}, |
||||
expErr: nil, |
||||
providerTypes: []passProvider{ |
||||
&promptPassProvider{}, |
||||
}, |
||||
}, |
||||
{ |
||||
config: passDecrypterConfig{ |
||||
passSrcType: PassSrcPrompt, |
||||
persistPassphrase: true, |
||||
}, |
||||
expErr: nil, |
||||
providerTypes: []passProvider{ |
||||
&promptPassProvider{}, |
||||
}, |
||||
}, |
||||
{ |
||||
config: passDecrypterConfig{ |
||||
passSrcType: PassSrcAuto, |
||||
}, |
||||
expErr: nil, |
||||
providerTypes: []passProvider{ |
||||
&dynamicPassProvider{}, |
||||
&promptPassProvider{}, |
||||
}, |
||||
}, |
||||
{ |
||||
config: passDecrypterConfig{ |
||||
passSrcType: PassSrcAuto, |
||||
passFile: &existPassFile, |
||||
}, |
||||
expErr: nil, |
||||
providerTypes: []passProvider{ |
||||
&staticPassProvider{}, |
||||
&promptPassProvider{}, |
||||
}, |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
decrypter, err := newPassDecrypter(test.config) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if err != nil { |
||||
continue |
||||
} |
||||
|
||||
if len(decrypter.pps) != len(test.providerTypes) { |
||||
t.Errorf("Test %v: unexpected provider number %v / %v", |
||||
i, len(decrypter.pps), len(test.providerTypes)) |
||||
continue |
||||
} |
||||
for ppIndex, gotPP := range decrypter.pps { |
||||
gotType := reflect.TypeOf(gotPP).Elem() |
||||
expType := reflect.TypeOf(test.providerTypes[ppIndex]).Elem() |
||||
if gotType != expType { |
||||
t.Errorf("Test %v: %v passProvider unexpected type: %v / %v", |
||||
i, ppIndex, gotType, expType) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPassDecrypter_decryptFile(t *testing.T) { |
||||
setTestConsole(newTestConsole()) |
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
tests := []struct { |
||||
setupFunc func(rootDir string) error |
||||
providers []passProvider |
||||
keyFile string |
||||
|
||||
expPublicKey string |
||||
expErr error |
||||
}{ |
||||
{ |
||||
// 1. Two providers, one return err and one return the correct passphrase.
|
||||
setupFunc: func(rootDir string) error { |
||||
keyFile := filepath.Join(rootDir, testKeys[1].publicKey+basicKeyExt) |
||||
return writeFile(keyFile, testKeys[1].keyFileData) |
||||
}, |
||||
providers: []passProvider{ |
||||
&errPassProvider{}, |
||||
makeTestPassProvider(), |
||||
}, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
expPublicKey: testKeys[1].publicKey, |
||||
}, |
||||
{ |
||||
// 2. Only error provider. Return the decryption error
|
||||
setupFunc: func(rootDir string) error { |
||||
keyFile := filepath.Join(rootDir, testKeys[1].publicKey+basicKeyExt) |
||||
return writeFile(keyFile, testKeys[1].keyFileData) |
||||
}, |
||||
providers: []passProvider{&errPassProvider{}}, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
expErr: errors.New("failed to load bls key"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i)) |
||||
os.RemoveAll(tcDir) |
||||
os.MkdirAll(tcDir, 0700) |
||||
if test.setupFunc != nil { |
||||
if err := test.setupFunc(tcDir); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
keyFile := filepath.Join(tcDir, test.keyFile) |
||||
|
||||
decrypter := &passDecrypter{pps: test.providers} |
||||
secret, err := decrypter.decryptFile(keyFile) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} |
||||
if err != nil || test.expErr != nil { |
||||
continue |
||||
} |
||||
gotPub := bls.FromLibBLSPublicKeyUnsafe(secret.GetPublicKey())[:] |
||||
if expPub := common.Hex2Bytes(test.expPublicKey); !bytes.Equal(gotPub, expPub) { |
||||
t.Errorf("Test %v: unexpected public key %v / %v", i, gotPub, expPub) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type testPassProvider struct { |
||||
m map[string]string |
||||
} |
||||
|
||||
func makeTestPassProvider() *testPassProvider { |
||||
return &testPassProvider{ |
||||
m: map[string]string{ |
||||
testKeys[0].publicKey: testKeys[0].passphrase, |
||||
testKeys[1].publicKey: testKeys[1].passphrase, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func (provider *testPassProvider) getPassphrase(keyFile string) (string, error) { |
||||
basename := filepath.Base(keyFile) |
||||
publicKey := strings.TrimSuffix(basename, filepath.Ext(basename)) |
||||
pass, exist := provider.m[publicKey] |
||||
if !exist { |
||||
return "", errors.New("passphrase not exist") |
||||
} |
||||
return pass, nil |
||||
} |
||||
|
||||
type errPassProvider struct{} |
||||
|
||||
func (provider *errPassProvider) getPassphrase(keyFile string) (string, error) { |
||||
return "", errors.New("error intended") |
||||
} |
||||
|
||||
func TestPromptPassProvider_getPassphrase(t *testing.T) { |
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
tests := []struct { |
||||
setupFunc func(rootDir string) error |
||||
keyFile string |
||||
passphrase string |
||||
enablePersist bool |
||||
extraInput []string |
||||
|
||||
expOutputLen int |
||||
expErr error |
||||
newPassFileContent bool |
||||
passFileExist bool |
||||
}{ |
||||
{ |
||||
setupFunc: nil, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
passphrase: "new key", |
||||
enablePersist: false, |
||||
extraInput: []string{}, |
||||
expOutputLen: 1, // prompt for passphrase
|
||||
passFileExist: false, |
||||
newPassFileContent: false, |
||||
}, |
||||
{ |
||||
// new pass file
|
||||
setupFunc: nil, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
passphrase: "new key", |
||||
enablePersist: true, |
||||
extraInput: []string{}, |
||||
expOutputLen: 1, // prompt for passphrase
|
||||
passFileExist: true, |
||||
newPassFileContent: true, |
||||
}, |
||||
{ |
||||
// exist pass file, do not overwrite
|
||||
setupFunc: func(rootDir string) error { |
||||
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt) |
||||
return writeFile(passFile, "old key") |
||||
}, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
passphrase: "new key", |
||||
enablePersist: true, |
||||
extraInput: []string{"n"}, |
||||
expOutputLen: 2, // prompt for passphrase and ask for overwrite
|
||||
passFileExist: true, |
||||
newPassFileContent: false, |
||||
}, |
||||
{ |
||||
// exist pass file, do overwrite
|
||||
setupFunc: func(rootDir string) error { |
||||
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt) |
||||
return writeFile(passFile, "old key") |
||||
}, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
passphrase: "new key", |
||||
enablePersist: true, |
||||
extraInput: []string{"y"}, |
||||
expOutputLen: 2, // prompt for passphrase and ask for overwrite
|
||||
passFileExist: true, |
||||
newPassFileContent: true, |
||||
}, |
||||
} |
||||
|
||||
for i, test := range tests { |
||||
tc := newTestConsole() |
||||
setTestConsole(tc) |
||||
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i)) |
||||
os.RemoveAll(tcDir) |
||||
os.MkdirAll(tcDir, 0700) |
||||
if test.setupFunc != nil { |
||||
if err := test.setupFunc(tcDir); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
tc.In <- test.passphrase |
||||
for _, in := range test.extraInput { |
||||
tc.In <- in |
||||
} |
||||
|
||||
ppd := &promptPassProvider{enablePersist: test.enablePersist} |
||||
keyFile := filepath.Join(tcDir, test.keyFile) |
||||
passphrase, err := ppd.getPassphrase(keyFile) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if passphrase != test.passphrase { |
||||
t.Errorf("Test %v: got unexpected passphrase: %v / %v", i, passphrase, test.passphrase) |
||||
continue |
||||
} |
||||
for index := 0; index != test.expOutputLen; index++ { |
||||
<-tc.Out |
||||
} |
||||
if isClean, msg := tc.checkClean(); !isClean { |
||||
t.Errorf("Test %v: console not clean: %v", i, msg) |
||||
continue |
||||
} |
||||
passFile := keyFileToPassFileFull(keyFile) |
||||
if !test.passFileExist { |
||||
if _, err := os.Stat(passFile); !os.IsNotExist(err) { |
||||
t.Errorf("Test %v: pass file exist %v", i, passFile) |
||||
} |
||||
} else { |
||||
b, err := ioutil.ReadFile(passFile) |
||||
if err != nil { |
||||
t.Error(err) |
||||
continue |
||||
} |
||||
if test.newPassFileContent && string(b) != test.passphrase { |
||||
t.Errorf("Test %v: unexpected passphrase from persist file: %v/ %v", |
||||
i, string(b), test.passphrase) |
||||
} |
||||
if !test.newPassFileContent && string(b) == test.passphrase { |
||||
t.Errorf("Test %v: passphrase content has changed", i) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestDynamicPassProvider_getPassPhrase(t *testing.T) { |
||||
unitTestDir := filepath.Join(baseTestDir, t.Name()) |
||||
|
||||
tests := []struct { |
||||
setupFunc func(rootDir string) error |
||||
keyFile string |
||||
expPass string |
||||
expErr error |
||||
}{ |
||||
{ |
||||
setupFunc: func(rootDir string) error { |
||||
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt) |
||||
return writeFile(passFile, "passphrase\n") |
||||
}, |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
expPass: "passphrase", |
||||
}, |
||||
{ |
||||
keyFile: testKeys[1].publicKey + basicKeyExt, |
||||
expErr: errors.New("no such file"), |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i)) |
||||
os.RemoveAll(tcDir) |
||||
os.MkdirAll(tcDir, 0700) |
||||
|
||||
if test.setupFunc != nil { |
||||
if err := test.setupFunc(tcDir); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
provider := &dynamicPassProvider{} |
||||
keyFile := filepath.Join(tcDir, test.keyFile) |
||||
got, err := provider.getPassphrase(keyFile) |
||||
|
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
continue |
||||
} |
||||
if err != nil { |
||||
continue |
||||
} |
||||
if got != test.expPass { |
||||
t.Errorf("Test %v: unexpected passphrase: %v / %v", i, got, test.expPass) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,92 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
|
||||
bls_core "github.com/harmony-one/bls/ffi/go/bls" |
||||
) |
||||
|
||||
func loadBasicKeyWithProvider(blsKeyFile string, pp passProvider) (*bls_core.SecretKey, error) { |
||||
pass, err := pp.getPassphrase(blsKeyFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
secretKey, err := LoadBLSKeyWithPassPhrase(blsKeyFile, pass) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return secretKey, nil |
||||
} |
||||
|
||||
func checkIsFile(path string) error { |
||||
info, err := os.Stat(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if info.IsDir() { |
||||
return fmt.Errorf("%v is directory", path) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func checkIsDir(path string) error { |
||||
info, err := os.Stat(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !info.IsDir() { |
||||
return fmt.Errorf("%v is a file", path) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func checkIsPassFile(path string) error { |
||||
if err := checkIsFile(path); err != nil { |
||||
return err |
||||
} |
||||
if filepath.Ext(path) != passExt { |
||||
return fmt.Errorf("pass file %v should have extension .pass", path) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func keyFileToPassFileFull(keyFile string) string { |
||||
return strings.TrimSuffix(keyFile, basicKeyExt) + passExt |
||||
} |
||||
|
||||
func promptGetPassword(prompt string) (string, error) { |
||||
if !strings.HasSuffix(prompt, ":") { |
||||
prompt += ":" |
||||
} |
||||
console.print(prompt) |
||||
return console.readPassword() |
||||
} |
||||
|
||||
const yesNoPrompt = "[y/n]: " |
||||
|
||||
func promptYesNo(prompt string) (bool, error) { |
||||
if !strings.HasSuffix(prompt, yesNoPrompt) { |
||||
prompt = prompt + yesNoPrompt |
||||
} |
||||
for { |
||||
console.print(prompt) |
||||
response, err := console.readln() |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
response = strings.TrimSpace(strings.ToLower(response)) |
||||
|
||||
if response == "y" || response == "yes" { |
||||
return true, nil |
||||
} else if response == "n" || response == "no" { |
||||
return false, nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
func stringIsSet(val *string) bool { |
||||
return val != nil && *val != "" |
||||
} |
@ -0,0 +1,152 @@ |
||||
package blsgen |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"testing" |
||||
) |
||||
|
||||
const testPrompt = yesNoPrompt |
||||
|
||||
func init() { |
||||
// Move the test data to temp directory
|
||||
os.RemoveAll(baseTestDir) |
||||
os.MkdirAll(baseTestDir, 0777) |
||||
} |
||||
|
||||
var baseTestDir = filepath.Join(".testdata") |
||||
|
||||
type testKey struct { |
||||
publicKey string |
||||
privateKey string |
||||
passphrase string |
||||
keyFileData string |
||||
} |
||||
|
||||
// testKeys are keys with valid passphrase and valid .pass file
|
||||
var testKeys = []testKey{ |
||||
{ |
||||
// key with empty passphrase
|
||||
publicKey: "0e969f8b302cf7648bc39652ca7a279a8562b72933a3f7cddac2252583280c7c3495c9ae854f00f6dd19c32fc5a17500", |
||||
privateKey: "78c88c331195591b396e3205830071901a7a79e14fd0ede7f06bfb4c5e9f3473", |
||||
passphrase: "", |
||||
keyFileData: "1d97f32175d8875f251e15805fd08f0cda794d827cb02d2de7b10d10f36f951d68347bef1e7a3018bd865c6966219cd9c4d20b055c50f8e09a6a3a1666b7c112450f643cc3c175f541fae75da8a843d47993fe89ec85788fd6ea2e98", |
||||
}, |
||||
{ |
||||
// key with non empty passphrase
|
||||
publicKey: "152beed46d7a0002ef0f960946008887eedd4775bdf2ed238809aa74e20d31fdca267443615cc6f4ede49d58911ee083", |
||||
privateKey: "c20fa8de733d08e27e3101436d41f6a3207b8bedad7525c6e91a77ae2a49cf56", |
||||
passphrase: "harmony", |
||||
keyFileData: "194a2d68c37f037f36b28a560402d64ab007f949313b63d9a08f5adb55a061681c70d9119df2d2cdcae5da6e484550c03bad63aae7c1332a3647ce633999ac4ddbb4a40e213c7e88e604784fef40da9d2f28b392c9fb2462f5e51e9c", |
||||
}, |
||||
} |
||||
|
||||
func writeFile(file string, data string) error { |
||||
dir := filepath.Dir(file) |
||||
os.MkdirAll(dir, 0700) |
||||
return ioutil.WriteFile(file, []byte(data), 0600) |
||||
} |
||||
|
||||
func TestPromptYesNo(t *testing.T) { |
||||
tests := []struct { |
||||
inputs []string |
||||
lenOutputs int |
||||
expRes bool |
||||
expErr error |
||||
}{ |
||||
{ |
||||
inputs: []string{"yes"}, |
||||
lenOutputs: 1, |
||||
expRes: true, |
||||
}, |
||||
{ |
||||
inputs: []string{"YES\n"}, |
||||
lenOutputs: 1, |
||||
expRes: true, |
||||
}, |
||||
{ |
||||
inputs: []string{"y"}, |
||||
lenOutputs: 1, |
||||
expRes: true, |
||||
}, |
||||
{ |
||||
inputs: []string{"Y"}, |
||||
lenOutputs: 1, |
||||
expRes: true, |
||||
}, |
||||
{ |
||||
inputs: []string{"\tY"}, |
||||
lenOutputs: 1, |
||||
expRes: true, |
||||
}, |
||||
{ |
||||
inputs: []string{"No"}, |
||||
lenOutputs: 1, |
||||
expRes: false, |
||||
}, |
||||
{ |
||||
inputs: []string{"\tn"}, |
||||
lenOutputs: 1, |
||||
expRes: false, |
||||
}, |
||||
{ |
||||
inputs: []string{"invalid input", "y"}, |
||||
lenOutputs: 2, |
||||
expRes: true, |
||||
}, |
||||
} |
||||
for i, test := range tests { |
||||
tc := newTestConsole() |
||||
setTestConsole(tc) |
||||
for _, input := range test.inputs { |
||||
tc.In <- input |
||||
} |
||||
|
||||
got, err := promptYesNo(testPrompt) |
||||
if assErr := assertError(err, test.expErr); assErr != nil { |
||||
t.Errorf("Test %v: %v", i, assErr) |
||||
} else if assErr != nil { |
||||
continue |
||||
} |
||||
|
||||
// check results
|
||||
if got != test.expRes { |
||||
t.Errorf("Test %v: result unexpected %v / %v", i, got, test.expRes) |
||||
} |
||||
gotOutputs := drainCh(tc.Out) |
||||
if len(gotOutputs) != test.lenOutputs { |
||||
t.Errorf("unexpected output size: %v / %v", len(gotOutputs), test.lenOutputs) |
||||
} |
||||
if clean, msg := tc.checkClean(); !clean { |
||||
t.Errorf("Test %v: console unclean with message [%v]", i, msg) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func drainCh(c chan string) []string { |
||||
var res []string |
||||
for { |
||||
select { |
||||
case gotOut := <-c: |
||||
res = append(res, gotOut) |
||||
default: |
||||
return res |
||||
} |
||||
} |
||||
} |
||||
|
||||
func assertError(got, expect error) error { |
||||
if (got == nil) != (expect == nil) { |
||||
return fmt.Errorf("unexpected error [%v] / [%v]", got, expect) |
||||
} |
||||
if (got == nil) || (expect == nil) { |
||||
return nil |
||||
} |
||||
if !strings.Contains(got.Error(), expect.Error()) { |
||||
return fmt.Errorf("unexpected error [%v] / [%v]", got, expect) |
||||
} |
||||
return nil |
||||
} |
Loading…
Reference in new issue