[node.sh] refactor kmsClientProvider to kmsProvider and awsConfigGetter

pull/3219/head
Jacky Wang 4 years ago
parent cc600f9879
commit b61953aeff
No known key found for this signature in database
GPG Key ID: 1085CE5F4FF5842C
  1. 36
      cmd/harmony/blsloader/helper.go
  2. 150
      cmd/harmony/blsloader/kmsProvider.go
  3. 58
      cmd/harmony/blsloader/loader.go
  4. 4
      cmd/harmony/blsloader/passProvider.go
  5. 2
      cmd/harmony/blsloader/utils.go

@ -86,20 +86,8 @@ func (loader *kmsSingleBlsLoader) loadKeys() (multibls.PrivateKeys, error) {
return multibls.GetPrivateKeys(secretKey), nil return multibls.GetPrivateKeys(secretKey), nil
} }
func (loader *kmsSingleBlsLoader) getKmsClientProvider() (kmsClientProvider, error) { func (loader *kmsSingleBlsLoader) getKmsClientProvider() (kmsProvider, error) {
switch loader.awsCfgSrcType { return newLazyKmsProvider(loader.kmsProviderConfig)
case AwsCfgSrcFile:
if stringIsSet(loader.awsConfigFile) {
return newFileKmsProvider(*loader.awsConfigFile), nil
}
return newSharedKmsProvider(), nil
case AwsCfgSrcPrompt:
return newPromptKmsProvider(defKmsPromptTimeout), nil
case AwsCfgSrcShared:
return newSharedKmsProvider(), nil
default:
return nil, errors.New("unknown aws config source type")
}
} }
// blsDirLoader is the helper structure for loading bls keys in a directory // blsDirLoader is the helper structure for loading bls keys in a directory
@ -111,7 +99,7 @@ type blsDirLoader struct {
// providers in process // providers in process
pps []passProvider pps []passProvider
kcp kmsClientProvider kcp kmsProvider
// result field // result field
secretKeys []*bls_core.SecretKey secretKeys []*bls_core.SecretKey
} }
@ -159,20 +147,8 @@ func (loader *blsDirLoader) getPromptPassProvider() passProvider {
return provider return provider
} }
func (loader *blsDirLoader) getKmsClientProvider() (kmsClientProvider, error) { func (loader *blsDirLoader) getKmsClientProvider() (kmsProvider, error) {
switch loader.awsCfgSrcType { return newLazyKmsProvider(loader.kmsProviderConfig)
case AwsCfgSrcFile:
if stringIsSet(loader.awsConfigFile) {
return newFileKmsProvider(*loader.awsConfigFile), nil
}
return newSharedKmsProvider(), nil
case AwsCfgSrcPrompt:
return newPromptKmsProvider(defKmsPromptTimeout), nil
case AwsCfgSrcShared:
return newSharedKmsProvider(), nil
default:
return nil, errors.New("unknown aws config source type")
}
} }
func (loader *blsDirLoader) loadKeyFiles() (multibls.PrivateKeys, error) { func (loader *blsDirLoader) loadKeyFiles() (multibls.PrivateKeys, error) {
@ -190,7 +166,7 @@ func (loader *blsDirLoader) processFileWalk(path string, info os.FileInfo, err e
// unexpected error, return the error and break the file walk loop // unexpected error, return the error and break the file walk loop
return err return err
} }
// expected error. Skipping these files // errors to be skipped. Skipping these files
skipStr := fmt.Sprintf("Skipping [%s]: %v\n", path, err) skipStr := fmt.Sprintf("Skipping [%s]: %v\n", path, err)
console.println(skipStr) console.println(skipStr)
return nil return nil

@ -37,49 +37,61 @@ type kmsProviderConfig struct {
awsConfigFile *string awsConfigFile *string
} }
func (cfg kmsProviderConfig) validate() error { func (config kmsProviderConfig) validate() error {
if !cfg.awsCfgSrcType.isValid() { if !config.awsCfgSrcType.isValid() {
return errors.New("unknown AwsCfgSrcType") return errors.New("unknown AwsCfgSrcType")
} }
if cfg.awsCfgSrcType == AwsCfgSrcFile { if config.awsCfgSrcType == AwsCfgSrcFile {
if !stringIsSet(cfg.awsConfigFile) { if !stringIsSet(config.awsConfigFile) {
return errors.New("config field AwsConfig file must set for AwsCfgSrcFile") return errors.New("config field AwsConfig file must set for AwsCfgSrcFile")
} }
if !isFile(*cfg.awsConfigFile) { if !isFile(*config.awsConfigFile) {
return fmt.Errorf("aws config file not exist %v", *cfg.awsConfigFile) return fmt.Errorf("aws config file not exist %v", *config.awsConfigFile)
} }
} }
return nil return nil
} }
// kmsClientProvider provides the kms client. Implemented by // kmsProvider provide the aws kms client
// baseKMSProvider - abstract implementation type kmsProvider interface {
// sharedKMSProvider - provide the client with default .aws folder
// fileKMSProvider - provide the aws config with a json file
// promptKMSProvider - provide the config field from prompt with time out
type kmsClientProvider interface {
// getKMSClient returns the KMSClient of the kmsClientProvider with lazy loading.
getKMSClient() (*kms.KMS, error) getKMSClient() (*kms.KMS, error)
// toStr return the string presentation of kmsClientProvider
toStr() string
} }
type getAwsCfgFunc func() (*AwsConfig, error) // lazyKmsProvider provide the kms client with singleton lazy initialization with config get
// from awsConfigGetter for aws credential and regions loading.
type lazyKmsProvider struct {
acGetter awsConfigGetter
// baseKMSProvider provide the kms client with singleton initialization through
// function getConfig for aws credential and regions loading.
type baseKMSProvider struct {
client *kms.KMS client *kms.KMS
err error err error
once sync.Once once sync.Once
}
getAWSConfig getAwsCfgFunc // newLazyKmsProvider creates a kmsProvider with the given config
func newLazyKmsProvider(config kmsProviderConfig) (*lazyKmsProvider, error) {
var acg awsConfigGetter
switch config.awsCfgSrcType {
case AwsCfgSrcFile:
if stringIsSet(config.awsConfigFile) {
acg = newFileACGetter(*config.awsConfigFile)
} else {
acg = newSharedAwsConfigGetter()
}
case AwsCfgSrcPrompt:
acg = newPromptACGetter(defKmsPromptTimeout)
case AwsCfgSrcShared:
acg = newSharedAwsConfigGetter()
default:
return nil, errors.New("unknown aws config source type")
}
return &lazyKmsProvider{
acGetter: acg,
}, nil
} }
func (provider *baseKMSProvider) getKMSClient() (*kms.KMS, error) { func (provider *lazyKmsProvider) getKMSClient() (*kms.KMS, error) {
provider.once.Do(func() { provider.once.Do(func() {
cfg, err := provider.getAWSConfig() cfg, err := provider.acGetter.getAwsConfig()
if err != nil { if err != nil {
provider.err = err provider.err = err
return return
@ -92,94 +104,80 @@ func (provider *baseKMSProvider) getKMSClient() (*kms.KMS, error) {
return provider.client, nil return provider.client, nil
} }
func (provider *baseKMSProvider) toStr() string { // awsConfigGetter provides the aws config. Implemented by
return "not implemented" // sharedACGetter - provide the nil to use shared AWS configuration
// fileACGetter - provide the aws config with a json file
// promptACGetter - provide the config field from prompt with time out
type awsConfigGetter interface {
getAwsConfig() (*AwsConfig, error)
String() string
} }
// sharedKMSProvider provide the kms session with the default aws config // sharedACGetter returns nil for getAwsConfig to use shared aws configurations
// locates in directory $HOME/.aws/config type sharedACGetter struct{}
type sharedKMSProvider struct {
baseKMSProvider
}
func newSharedKmsProvider() *sharedKMSProvider { func newSharedAwsConfigGetter() *sharedACGetter {
provider := &sharedKMSProvider{baseKMSProvider{}} return &sharedACGetter{}
provider.baseKMSProvider.getAWSConfig = provider.getAWSConfig
return provider
} }
// TODO(Jacky): set getAwsConfig into a function, not bind with structure func (getter *sharedACGetter) getAwsConfig() (*AwsConfig, error) {
func (provider *sharedKMSProvider) getAWSConfig() (*AwsConfig, error) {
return nil, nil return nil, nil
} }
func (provider *sharedKMSProvider) toStr() string { func (getter *sharedACGetter) String() string {
return "shared aws config" return "shared aws config"
} }
// fileKMSProvider provide the kms session from a file with json data of structure // fileACGetter get aws config through a customized json file
// AwsConfig type fileACGetter struct {
type fileKMSProvider struct {
baseKMSProvider
file string file string
} }
func newFileKmsProvider(file string) *fileKMSProvider { func newFileACGetter(file string) *fileACGetter {
provider := &fileKMSProvider{ return &fileACGetter{file}
baseKMSProvider: baseKMSProvider{},
file: file,
}
provider.baseKMSProvider.getAWSConfig = provider.getAWSConfig
return provider
} }
func (provider *fileKMSProvider) getAWSConfig() (*AwsConfig, error) { func (getter *fileACGetter) getAwsConfig() (*AwsConfig, error) {
b, err := ioutil.ReadFile(provider.file) b, err := ioutil.ReadFile(getter.file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var cfg *AwsConfig var cfg AwsConfig
if err := json.Unmarshal(b, cfg); err != nil { if err := json.Unmarshal(b, &cfg); err != nil {
return nil, err return nil, err
} }
return cfg, nil return &cfg, nil
} }
func (provider *fileKMSProvider) toStr() string { func (getter *fileACGetter) String() string {
return fmt.Sprintf("file %v", provider.file) return fmt.Sprintf("file %v", getter.file)
} }
// promptKMSProvider provide a user interactive console for AWS config. // promptACGetter provide a user interactive console for AWS config.
// Three fields are asked: // Four fields are asked:
// 1. AccessKey 2. SecretKey 3. Region // 1. AccessKey 2. SecretKey 3. Region
// Each field is asked with a timeout mechanism. // Each field is asked with a timeout mechanism.
type promptKMSProvider struct { type promptACGetter struct {
baseKMSProvider
timeout time.Duration timeout time.Duration
} }
func newPromptKmsProvider(timeout time.Duration) *promptKMSProvider { func newPromptACGetter(timeout time.Duration) *promptACGetter {
provider := &promptKMSProvider{ return &promptACGetter{
baseKMSProvider: baseKMSProvider{}, timeout: timeout,
timeout: timeout,
} }
provider.baseKMSProvider.getAWSConfig = provider.getAWSConfig
return provider
} }
func (provider *promptKMSProvider) getAWSConfig() (*AwsConfig, error) { func (getter *promptACGetter) getAwsConfig() (*AwsConfig, error) {
console.println("Please provide AWS configurations for KMS encoded BLS keys:") console.println("Please provide AWS configurations for KMS encoded BLS keys:")
accessKey, err := provider.prompt(" AccessKey:") accessKey, err := getter.prompt(" AccessKey:")
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get aws access key: %v", err) return nil, fmt.Errorf("cannot get aws access key: %v", err)
} }
secretKey, err := provider.prompt(" SecretKey:") secretKey, err := getter.prompt(" SecretKey:")
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get aws secret key: %v", err) return nil, fmt.Errorf("cannot get aws secret key: %v", err)
} }
region, err := provider.prompt("Region:") region, err := getter.prompt("Region:")
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot get aws region: %v", err) return nil, fmt.Errorf("cannot get aws region: %v", err)
} }
@ -192,17 +190,17 @@ func (provider *promptKMSProvider) getAWSConfig() (*AwsConfig, error) {
} }
// prompt prompt the user to input a string for a certain field with timeout. // prompt prompt the user to input a string for a certain field with timeout.
func (provider *promptKMSProvider) prompt(hint string) (string, error) { func (getter *promptACGetter) prompt(hint string) (string, error) {
var ( var (
res string res string
err error err error
finished = make(chan struct{}) finished = make(chan struct{})
timedOut = time.After(provider.timeout) timedOut = time.After(getter.timeout)
) )
go func() { go func() {
res, err = provider.threadedPrompt(hint) res, err = getter.threadedPrompt(hint)
close(finished) close(finished)
}() }()
@ -216,12 +214,12 @@ func (provider *promptKMSProvider) prompt(hint string) (string, error) {
} }
} }
func (provider *promptKMSProvider) threadedPrompt(hint string) (string, error) { func (getter *promptACGetter) threadedPrompt(hint string) (string, error) {
console.print(hint) console.print(hint)
return console.readPassword() return console.readPassword()
} }
func (provider *promptKMSProvider) toStr() string { func (getter *promptACGetter) String() string {
return "prompt" return "prompt"
} }

@ -20,41 +20,13 @@ func LoadKeys(cfg Config) (multibls.PrivateKeys, error) {
return helper.loadKeys() return helper.loadKeys()
} }
func getHelper(cfg Config) (loadHelper, error) {
switch {
case stringIsSet(cfg.BlsKeyFile):
switch filepath.Ext(*cfg.BlsKeyFile) {
case basicKeyExt:
return &basicSingleBlsLoader{
blsKeyFile: *cfg.BlsKeyFile,
passProviderConfig: cfg.getPassProviderConfig(),
}, nil
case kmsKeyExt:
return &kmsSingleBlsLoader{
blsKeyFile: *cfg.BlsKeyFile,
kmsProviderConfig: cfg.getKmsProviderConfig(),
}, nil
default:
return nil, errors.New("unknown extension")
}
case stringIsSet(cfg.BlsDir):
return &blsDirLoader{
dirPath: *cfg.BlsDir,
passProviderConfig: cfg.getPassProviderConfig(),
kmsProviderConfig: cfg.getKmsProviderConfig(),
}, nil
default:
return nil, errors.New("either BlsKeyFile or BlsDir must be set")
}
}
// Loader is the structure to load bls keys. // Loader is the structure to load bls keys.
type Config struct { type Config struct {
// source for bls key loading. At least one of the BlsKeyFile and BlsDir // source for bls key loading. At least one of the BlsKeyFile and BlsDir
// need to be provided. // need to be provided.
// //
// BlsKeyFile defines a single key file to load from. Based on the file // BlsKeyFile defines a single key file to load from. Based on the file
// extension, decryption with passphrase or aws kms will be used. // extension, decryption with either passphrase or aws kms will be used.
BlsKeyFile *string BlsKeyFile *string
// BlsDir defines a file directory to load keys from. // BlsDir defines a file directory to load keys from.
BlsDir *string BlsDir *string
@ -134,3 +106,31 @@ func (cfg *Config) getKmsProviderConfig() kmsProviderConfig {
awsConfigFile: cfg.AwsConfigFile, awsConfigFile: cfg.AwsConfigFile,
} }
} }
func getHelper(cfg Config) (loadHelper, error) {
switch {
case stringIsSet(cfg.BlsKeyFile):
switch filepath.Ext(*cfg.BlsKeyFile) {
case basicKeyExt:
return &basicSingleBlsLoader{
blsKeyFile: *cfg.BlsKeyFile,
passProviderConfig: cfg.getPassProviderConfig(),
}, nil
case kmsKeyExt:
return &kmsSingleBlsLoader{
blsKeyFile: *cfg.BlsKeyFile,
kmsProviderConfig: cfg.getKmsProviderConfig(),
}, nil
default:
return nil, errors.New("unknown extension")
}
case stringIsSet(cfg.BlsDir):
return &blsDirLoader{
dirPath: *cfg.BlsDir,
passProviderConfig: cfg.getPassProviderConfig(),
kmsProviderConfig: cfg.getKmsProviderConfig(),
}, nil
default:
return nil, errors.New("either BlsKeyFile or BlsDir must be set")
}
}

@ -25,8 +25,8 @@ func (config passProviderConfig) validate() error {
// passProvider is the interface to provide the passphrase of a bls keys. // passProvider is the interface to provide the passphrase of a bls keys.
// Implemented by // Implemented by
// promptPassProvider - provide passphrase through user-interactive prompt // promptPassProvider - provide passphrase through user-interactive prompt
// filePassProvider - provide passphrase from a .pass file // filePassProvider - provide passphrase from a .pass file
// dirPassProvider - provide passphrase from .pass files in a directory // dirPassProvider - provide passphrase from .pass files in a directory
type passProvider interface { type passProvider interface {
toStr() string toStr() string
getPassphrase(keyFile string) (string, error) getPassphrase(keyFile string) (string, error)

@ -47,7 +47,7 @@ func loadBasicKeyWithProvider(blsKeyFile string, pp passProvider) (*bls_core.Sec
} }
// loadKmsKeyFromFile loads a single KMS BLS key from file // loadKmsKeyFromFile loads a single KMS BLS key from file
func loadKmsKeyFromFile(blsKeyFile string, kcp kmsClientProvider) (*bls_core.SecretKey, error) { func loadKmsKeyFromFile(blsKeyFile string, kcp kmsProvider) (*bls_core.SecretKey, error) {
if kcp == nil { if kcp == nil {
return nil, errNilKMSClientProvider return nil, errNilKMSClientProvider
} }

Loading…
Cancel
Save