You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
381 lines
9.9 KiB
381 lines
9.9 KiB
package blsgen
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ethereum/go-ethereum/common"
|
|
"github.com/harmony-one/harmony/crypto/bls"
|
|
)
|
|
|
|
func TestNewPassDecrypter(t *testing.T) {
|
|
// setup
|
|
var (
|
|
testDir = filepath.Join(baseTestDir, t.Name())
|
|
existPassFile = filepath.Join(testDir, testKeys[0].publicKey+passExt)
|
|
emptyPassFile = filepath.Join(testDir, testKeys[1].publicKey+passExt)
|
|
)
|
|
if err := writeFile(existPassFile, testKeys[0].passphrase); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tests := []struct {
|
|
config passDecrypterConfig
|
|
expErr error
|
|
providerTypes []passProvider
|
|
}{
|
|
{
|
|
config: passDecrypterConfig{passSrcType: PassSrcNil},
|
|
expErr: errors.New("unknown PassSrcType"),
|
|
},
|
|
{
|
|
config: passDecrypterConfig{
|
|
passSrcType: PassSrcFile,
|
|
passFile: &emptyPassFile,
|
|
},
|
|
expErr: errors.New("no such file or directory"),
|
|
},
|
|
{
|
|
config: passDecrypterConfig{
|
|
passSrcType: PassSrcFile,
|
|
passFile: &existPassFile,
|
|
},
|
|
expErr: nil,
|
|
providerTypes: []passProvider{
|
|
&staticPassProvider{},
|
|
},
|
|
},
|
|
{
|
|
config: passDecrypterConfig{passSrcType: PassSrcPrompt},
|
|
expErr: nil,
|
|
providerTypes: []passProvider{
|
|
&promptPassProvider{},
|
|
},
|
|
},
|
|
{
|
|
config: passDecrypterConfig{
|
|
passSrcType: PassSrcPrompt,
|
|
persistPassphrase: true,
|
|
},
|
|
expErr: nil,
|
|
providerTypes: []passProvider{
|
|
&promptPassProvider{},
|
|
},
|
|
},
|
|
{
|
|
config: passDecrypterConfig{
|
|
passSrcType: PassSrcAuto,
|
|
},
|
|
expErr: nil,
|
|
providerTypes: []passProvider{
|
|
&dynamicPassProvider{},
|
|
&promptPassProvider{},
|
|
},
|
|
},
|
|
{
|
|
config: passDecrypterConfig{
|
|
passSrcType: PassSrcAuto,
|
|
passFile: &existPassFile,
|
|
},
|
|
expErr: nil,
|
|
providerTypes: []passProvider{
|
|
&staticPassProvider{},
|
|
&promptPassProvider{},
|
|
},
|
|
},
|
|
}
|
|
for i, test := range tests {
|
|
decrypter, err := newPassDecrypter(test.config)
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if len(decrypter.pps) != len(test.providerTypes) {
|
|
t.Errorf("Test %v: unexpected provider number %v / %v",
|
|
i, len(decrypter.pps), len(test.providerTypes))
|
|
continue
|
|
}
|
|
for ppIndex, gotPP := range decrypter.pps {
|
|
gotType := reflect.TypeOf(gotPP).Elem()
|
|
expType := reflect.TypeOf(test.providerTypes[ppIndex]).Elem()
|
|
if gotType != expType {
|
|
t.Errorf("Test %v: %v passProvider unexpected type: %v / %v",
|
|
i, ppIndex, gotType, expType)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPassDecrypter_decryptFile(t *testing.T) {
|
|
setTestConsole(newTestConsole())
|
|
unitTestDir := filepath.Join(baseTestDir, t.Name())
|
|
tests := []struct {
|
|
setupFunc func(rootDir string) error
|
|
providers []passProvider
|
|
keyFile string
|
|
|
|
expPublicKey string
|
|
expErr error
|
|
}{
|
|
{
|
|
// 1. Two providers, one return err and one return the correct passphrase.
|
|
setupFunc: func(rootDir string) error {
|
|
keyFile := filepath.Join(rootDir, testKeys[1].publicKey+basicKeyExt)
|
|
return writeFile(keyFile, testKeys[1].keyFileData)
|
|
},
|
|
providers: []passProvider{
|
|
&errPassProvider{},
|
|
makeTestPassProvider(),
|
|
},
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
expPublicKey: testKeys[1].publicKey,
|
|
},
|
|
{
|
|
// 2. Only error provider. Return the decryption error
|
|
setupFunc: func(rootDir string) error {
|
|
keyFile := filepath.Join(rootDir, testKeys[1].publicKey+basicKeyExt)
|
|
return writeFile(keyFile, testKeys[1].keyFileData)
|
|
},
|
|
providers: []passProvider{&errPassProvider{}},
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
expErr: errors.New("failed to load bls key"),
|
|
},
|
|
}
|
|
for i, test := range tests {
|
|
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i))
|
|
os.RemoveAll(tcDir)
|
|
os.MkdirAll(tcDir, 0700)
|
|
if test.setupFunc != nil {
|
|
if err := test.setupFunc(tcDir); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
keyFile := filepath.Join(tcDir, test.keyFile)
|
|
|
|
decrypter := &passDecrypter{pps: test.providers}
|
|
secret, err := decrypter.decryptFile(keyFile)
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
}
|
|
if err != nil || test.expErr != nil {
|
|
continue
|
|
}
|
|
gotPub := bls.FromLibBLSPublicKeyUnsafe(secret.GetPublicKey())[:]
|
|
if expPub := common.Hex2Bytes(test.expPublicKey); !bytes.Equal(gotPub, expPub) {
|
|
t.Errorf("Test %v: unexpected public key %v / %v", i, gotPub, expPub)
|
|
}
|
|
}
|
|
}
|
|
|
|
type testPassProvider struct {
|
|
m map[string]string
|
|
}
|
|
|
|
func makeTestPassProvider() *testPassProvider {
|
|
return &testPassProvider{
|
|
m: map[string]string{
|
|
testKeys[0].publicKey: testKeys[0].passphrase,
|
|
testKeys[1].publicKey: testKeys[1].passphrase,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (provider *testPassProvider) getPassphrase(keyFile string) (string, error) {
|
|
basename := filepath.Base(keyFile)
|
|
publicKey := strings.TrimSuffix(basename, filepath.Ext(basename))
|
|
pass, exist := provider.m[publicKey]
|
|
if !exist {
|
|
return "", errors.New("passphrase not exist")
|
|
}
|
|
return pass, nil
|
|
}
|
|
|
|
// errPassProvider is a passProvider stub whose getPassphrase always fails.
type errPassProvider struct{}

// getPassphrase ignores keyFile and always returns an "error intended" error
// together with an empty passphrase.
func (*errPassProvider) getPassphrase(keyFile string) (string, error) {
	var pass string
	return pass, errors.New("error intended")
}
|
|
|
|
func TestPromptPassProvider_getPassphrase(t *testing.T) {
|
|
unitTestDir := filepath.Join(baseTestDir, t.Name())
|
|
tests := []struct {
|
|
setupFunc func(rootDir string) error
|
|
keyFile string
|
|
passphrase string
|
|
enablePersist bool
|
|
extraInput []string
|
|
|
|
expOutputLen int
|
|
expErr error
|
|
newPassFileContent bool
|
|
passFileExist bool
|
|
}{
|
|
{
|
|
setupFunc: nil,
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
passphrase: "new key",
|
|
enablePersist: false,
|
|
extraInput: []string{},
|
|
expOutputLen: 1, // prompt for passphrase
|
|
passFileExist: false,
|
|
newPassFileContent: false,
|
|
},
|
|
{
|
|
// new pass file
|
|
setupFunc: nil,
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
passphrase: "new key",
|
|
enablePersist: true,
|
|
extraInput: []string{},
|
|
expOutputLen: 1, // prompt for passphrase
|
|
passFileExist: true,
|
|
newPassFileContent: true,
|
|
},
|
|
{
|
|
// exist pass file, do not overwrite
|
|
setupFunc: func(rootDir string) error {
|
|
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt)
|
|
return writeFile(passFile, "old key")
|
|
},
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
passphrase: "new key",
|
|
enablePersist: true,
|
|
extraInput: []string{"n"},
|
|
expOutputLen: 2, // prompt for passphrase and ask for overwrite
|
|
passFileExist: true,
|
|
newPassFileContent: false,
|
|
},
|
|
{
|
|
// exist pass file, do overwrite
|
|
setupFunc: func(rootDir string) error {
|
|
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt)
|
|
return writeFile(passFile, "old key")
|
|
},
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
passphrase: "new key",
|
|
enablePersist: true,
|
|
extraInput: []string{"y"},
|
|
expOutputLen: 2, // prompt for passphrase and ask for overwrite
|
|
passFileExist: true,
|
|
newPassFileContent: true,
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
tc := newTestConsole()
|
|
setTestConsole(tc)
|
|
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i))
|
|
os.RemoveAll(tcDir)
|
|
os.MkdirAll(tcDir, 0700)
|
|
if test.setupFunc != nil {
|
|
if err := test.setupFunc(tcDir); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
tc.In <- test.passphrase
|
|
for _, in := range test.extraInput {
|
|
tc.In <- in
|
|
}
|
|
|
|
ppd := &promptPassProvider{enablePersist: test.enablePersist}
|
|
keyFile := filepath.Join(tcDir, test.keyFile)
|
|
passphrase, err := ppd.getPassphrase(keyFile)
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if passphrase != test.passphrase {
|
|
t.Errorf("Test %v: got unexpected passphrase: %v / %v", i, passphrase, test.passphrase)
|
|
continue
|
|
}
|
|
for index := 0; index != test.expOutputLen; index++ {
|
|
<-tc.Out
|
|
}
|
|
if isClean, msg := tc.checkClean(); !isClean {
|
|
t.Errorf("Test %v: console not clean: %v", i, msg)
|
|
continue
|
|
}
|
|
passFile := keyFileToPassFileFull(keyFile)
|
|
if !test.passFileExist {
|
|
if _, err := os.Stat(passFile); !os.IsNotExist(err) {
|
|
t.Errorf("Test %v: pass file exist %v", i, passFile)
|
|
}
|
|
} else {
|
|
b, err := ioutil.ReadFile(passFile)
|
|
if err != nil {
|
|
t.Error(err)
|
|
continue
|
|
}
|
|
if test.newPassFileContent && string(b) != test.passphrase {
|
|
t.Errorf("Test %v: unexpected passphrase from persist file: %v/ %v",
|
|
i, string(b), test.passphrase)
|
|
}
|
|
if !test.newPassFileContent && string(b) == test.passphrase {
|
|
t.Errorf("Test %v: passphrase content has changed", i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDynamicPassProvider_getPassPhrase(t *testing.T) {
|
|
unitTestDir := filepath.Join(baseTestDir, t.Name())
|
|
|
|
tests := []struct {
|
|
setupFunc func(rootDir string) error
|
|
keyFile string
|
|
expPass string
|
|
expErr error
|
|
}{
|
|
{
|
|
setupFunc: func(rootDir string) error {
|
|
passFile := filepath.Join(rootDir, testKeys[1].publicKey+passExt)
|
|
return writeFile(passFile, "passphrase\n")
|
|
},
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
expPass: "passphrase",
|
|
},
|
|
{
|
|
keyFile: testKeys[1].publicKey + basicKeyExt,
|
|
expErr: errors.New("no such file"),
|
|
},
|
|
}
|
|
for i, test := range tests {
|
|
tcDir := filepath.Join(unitTestDir, fmt.Sprintf("%v", i))
|
|
os.RemoveAll(tcDir)
|
|
os.MkdirAll(tcDir, 0700)
|
|
|
|
if test.setupFunc != nil {
|
|
if err := test.setupFunc(tcDir); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
provider := &dynamicPassProvider{}
|
|
keyFile := filepath.Join(tcDir, test.keyFile)
|
|
got, err := provider.getPassphrase(keyFile)
|
|
|
|
if assErr := assertError(err, test.expErr); assErr != nil {
|
|
t.Errorf("Test %v: %v", i, assErr)
|
|
continue
|
|
}
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if got != test.expPass {
|
|
t.Errorf("Test %v: unexpected passphrase: %v / %v", i, got, test.expPass)
|
|
}
|
|
}
|
|
}
|
|
|