[node.sh] added a lot of test caess

pull/3219/head
Jacky Wang 4 years ago
parent 60ae574001
commit e519e0e62a
No known key found for this signature in database
GPG Key ID: 1085CE5F4FF5842C
  1. 13
      cmd/harmony/blsloader/kms.go
  2. 151
      cmd/harmony/blsloader/kms_test.go
  3. 88
      cmd/harmony/blsloader/loader.go
  4. 2
      cmd/harmony/blsloader/passphrase.go
  5. 51
      cmd/harmony/blsloader/passphrase_test.go

@ -146,7 +146,6 @@ func (cfg AwsConfig) toAws() *aws.Config {
// promptACProvider - provide the config field from prompt with time out
type awsConfigProvider interface {
getAwsConfig() (*AwsConfig, error)
String() string
}
// sharedACProvider returns nil for getAwsConfig to use shared aws configurations
@ -160,10 +159,6 @@ func (provider *sharedACProvider) getAwsConfig() (*AwsConfig, error) {
return nil, nil
}
func (provider *sharedACProvider) String() string {
return "shared aws config"
}
// fileACProvider get aws config through a customized json file
type fileACProvider struct {
file string
@ -185,10 +180,6 @@ func (provider *fileACProvider) getAwsConfig() (*AwsConfig, error) {
return &cfg, nil
}
func (provider *fileACProvider) String() string {
return fmt.Sprintf("file %v", provider.file)
}
// promptACProvider provide a user interactive console for AWS config.
// Four fields are asked:
// 1. AccessKey 2. SecretKey 3. Region
@ -255,10 +246,6 @@ func (provider *promptACProvider) threadedPrompt(hint string) (string, error) {
return console.readPassword()
}
func (provider *promptACProvider) String() string {
return "prompt"
}
func kmsClientWithConfig(config *AwsConfig) (*kms.KMS, error) {
if config == nil {
return getSharedKMSClient()

@ -0,0 +1,151 @@
package blsloader
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"testing"
"time"
)
var TestAwsConfig = AwsConfig{
AccessKey: "access key",
SecretKey: "secret key",
Region: "region",
}
//func TestNewKmsDecrypter(t *testing.T) {
// tests := []struct {
//
// }
//}
func TestPromptACProvider_getAwsConfig(t *testing.T) {
tc := newTestConsole()
setTestConsole(tc)
for _, input := range []string{
TestAwsConfig.AccessKey,
TestAwsConfig.SecretKey,
TestAwsConfig.Region,
} {
tc.In <- input
}
provider := newPromptACProvider(1 * time.Second)
got, err := provider.getAwsConfig()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*got, TestAwsConfig) {
t.Errorf("unexpected result %+v / %+v", got, TestAwsConfig)
}
}
func TestPromptACProvider_prompt(t *testing.T) {
tests := []struct {
delay, timeout time.Duration
expErr error
}{
{
delay: 100 * time.Microsecond,
timeout: 1000 * time.Microsecond,
expErr: nil,
},
{
delay: 2000 * time.Microsecond,
timeout: 1000 * time.Microsecond,
expErr: errors.New("timed out"),
},
}
for i, test := range tests {
tc := newTestConsole()
setTestConsole(tc)
testInput := "test"
go func() {
<-time.After(test.delay)
tc.In <- testInput
}()
provider := newPromptACProvider(test.timeout)
got, err := provider.prompt("test ask string")
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if err != nil {
continue
}
if got != testInput {
t.Errorf("Test %v: unexpected prompt result: %v / %v", i, got, testInput)
}
}
}
func TestFileACProvider_getAwsConfig(t *testing.T) {
jsonBytes, err := json.Marshal(TestAwsConfig)
if err != nil {
t.Fatal(err)
}
unitTestDir := filepath.Join(baseTestDir, t.Name())
tests := []struct {
setupFunc func(rootDir string) error
expConfig AwsConfig
jsonFile string
expErr error
}{
{
// positive
setupFunc: func(rootDir string) error {
jsonFile := filepath.Join(rootDir, "valid.json")
return writeFile(jsonFile, string(jsonBytes))
},
jsonFile: "valid.json",
expConfig: TestAwsConfig,
},
{
// no such file
setupFunc: nil,
jsonFile: "empty.json",
expErr: errors.New("no such file"),
},
{
// invalid json string
setupFunc: func(rootDir string) error {
jsonFile := filepath.Join(rootDir, "invalid.json")
return writeFile(jsonFile, string(jsonBytes[:len(jsonBytes)-2]))
},
jsonFile: "invalid.json",
expErr: errors.New("unexpected end of JSON input"),
},
}
for i, test := range tests {
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i))
os.RemoveAll(tcDir)
os.MkdirAll(tcDir, 0700)
if test.setupFunc != nil {
if err := test.setupFunc(tcDir); err != nil {
t.Fatal(err)
}
}
provider := newFileACProvider(filepath.Join(tcDir, test.jsonFile))
got, err := provider.getAwsConfig()
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if err != nil {
continue
}
if got == nil || !reflect.DeepEqual(*got, test.expConfig) {
t.Errorf("Test %v: unexpected AwsConfig: %+v / %+v", i,
got, test.expConfig)
}
}
}

@ -20,48 +20,6 @@ func LoadKeys(cfg Config) (multibls.PrivateKeys, error) {
return helper.loadKeys()
}
// keyDecrypter is the interface to decrypt the bls key file. Currently, two
// implementations are supported:
// passDecrypter - decrypt with passphrase
// kmsDecrypter - decrypt with aws kms service
type keyDecrypter interface {
extension() string
decryptFile(keyFile string) (*bls_core.SecretKey, error)
}
func getKeyDecrypters(cfg Config) ([]keyDecrypter, error) {
var decrypters []keyDecrypter
if cfg.PassSrcType != PassSrcNil {
pd, err := newPassDecrypter(cfg.getPassProviderConfig())
if err != nil {
return nil, err
}
decrypters = append(decrypters, pd)
}
if cfg.AwsCfgSrcType != AwsCfgSrcNil {
kd, err := newKmsDecrypter(cfg.getKmsProviderConfig())
if err != nil {
return nil, err
}
decrypters = append(decrypters, kd)
}
if len(decrypters) == 0 {
return nil, fmt.Errorf("must provide at least one bls key decryption")
}
return decrypters, nil
}
func getHelper(cfg Config, decrypters []keyDecrypter) (loadHelper, error) {
switch {
case len(cfg.multiBlsKeys) != 0:
return newMultiKeyLoader(cfg.multiBlsKeys, decrypters)
case stringIsSet(cfg.BlsDir):
return newBlsDirLoader(*cfg.BlsDir, decrypters)
default:
return nil, errors.New("either multiBlsKeys or BlsDir must be set")
}
}
// Loader is the structure to load bls keys.
type Config struct {
// source for bls key loading. At least one of the multiBlsKeys and BlsDir
@ -75,10 +33,10 @@ type Config struct {
// Passphrase related settings. Used for passphrase encrypted key files.
//
// PassSrcType defines the source to get passphrase. Three source types are available
// PassSrcNil - do not use passphrase decryption
// PassSrcFile - get passphrase from a .pass file
// PassSrcPrompt - get passphrase from prompt
// PassSrcAuto - try to unlock with .pass file. If not success, ask user with prompt
// Value is default to PassSrcAuto.
PassSrcType PassSrcType
// PassFile specifies the .pass file to be used when loading passphrase from file.
// If not set, default to the .pass file in the same directory as the key file.
@ -92,10 +50,10 @@ type Config struct {
// Used for KMS encrypted passphrase files.
//
// AwsCfgSrcType defines the source to get aws config. Three types available:
// AwsCfgSrcNil - do not use Aws KMS decryption service.
// AwsCfgSrcFile - get AWS config through a json file. See AwsConfig for content fields.
// AwsCfgSrcPrompt - get AWS config through prompt.
// AwsCfgSrcShared - Use the default AWS config settings (from env and $HOME/.aws/config)
// Default to AwsCfgSrcShared.
AwsCfgSrcType AwsCfgSrcType
// AwsConfigFile set the json file to load aws config.
AwsConfigFile *string
@ -115,3 +73,45 @@ func (cfg *Config) getKmsProviderConfig() kmsDecrypterConfig {
awsConfigFile: cfg.AwsConfigFile,
}
}
// keyDecrypter is the interface to decrypt the bls key file. Currently, two
// implementations are supported:
// passDecrypter - decrypt with passphrase
// kmsDecrypter - decrypt with aws kms service
type keyDecrypter interface {
extension() string
decryptFile(keyFile string) (*bls_core.SecretKey, error)
}
func getKeyDecrypters(cfg Config) ([]keyDecrypter, error) {
var decrypters []keyDecrypter
if cfg.PassSrcType != PassSrcNil {
pd, err := newPassDecrypter(cfg.getPassProviderConfig())
if err != nil {
return nil, err
}
decrypters = append(decrypters, pd)
}
if cfg.AwsCfgSrcType != AwsCfgSrcNil {
kd, err := newKmsDecrypter(cfg.getKmsProviderConfig())
if err != nil {
return nil, err
}
decrypters = append(decrypters, kd)
}
if len(decrypters) == 0 {
return nil, fmt.Errorf("must provide at least one bls key decryption")
}
return decrypters, nil
}
func getHelper(cfg Config, decrypters []keyDecrypter) (loadHelper, error) {
switch {
case len(cfg.multiBlsKeys) != 0:
return newMultiKeyLoader(cfg.multiBlsKeys, decrypters)
case stringIsSet(cfg.BlsDir):
return newBlsDirLoader(*cfg.BlsDir, decrypters)
default:
return nil, errors.New("either multiBlsKeys or BlsDir must be set")
}
}

@ -211,7 +211,7 @@ func (provider *dynamicPassProvider) getPassphrase(keyFile string) (string, erro
func readPassFromFile(file string) (string, error) {
f, err := os.Open(file)
if err != nil {
return "", fmt.Errorf("cannot open passphrase file: %v", err)
return "", err
}
defer f.Close()

@ -160,7 +160,7 @@ func TestPromptPassProvider_getPassphrase(t *testing.T) {
newPassFileContent: true,
},
{
// exist pass file, not overwrite
// exist pass file, do not overwrite
setupFunc: func(rootDir string) error {
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt)
return writeFile(passFile, "old key")
@ -245,3 +245,52 @@ func TestPromptPassProvider_getPassphrase(t *testing.T) {
}
}
}
func TestDynamicPassProvider_getPassPhrase(t *testing.T) {
unitTestDir := filepath.Join(baseTestDir, t.Name())
tests := []struct {
setupFunc func(rootDir string) error
keyFile string
expPass string
expErr error
}{
{
setupFunc: func(rootDir string) error {
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt)
return writeFile(passFile, "passphrase\n")
},
keyFile: testKeys[1].publicKey + basicKeyExt,
expPass: "passphrase",
},
{
keyFile: testKeys[1].publicKey + basicKeyExt,
expErr: errors.New("no such file"),
},
}
for i, test := range tests {
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i))
os.RemoveAll(tcDir)
os.MkdirAll(tcDir, 0700)
if test.setupFunc != nil {
if err := test.setupFunc(tcDir); err != nil {
t.Fatal(err)
}
}
provider := &dynamicPassProvider{}
keyFile := filepath.Join(tcDir, test.keyFile)
got, err := provider.getPassphrase(keyFile)
if assErr := assertError(err, test.expErr); assErr != nil {
t.Errorf("Test %v: %v", i, assErr)
continue
}
if err != nil {
continue
}
if got != test.expPass {
t.Errorf("Test %v: unexpected passphrase: %v / %v", i, got, test.expPass)
}
}
}

Loading…
Cancel
Save