diff --git a/crypto.go b/crypto.go index 43c0846..58a038b 100644 --- a/crypto.go +++ b/crypto.go @@ -17,6 +17,7 @@ import ( "crypto/sha256" "fmt" "math/big" + "regexp" "strconv" "strings" @@ -37,42 +38,38 @@ var ( okmSize = 48 ) -// PrivateKeyFromSeedAndPath generates a private key given a seed and a path. -// Follows ERC-2334. -func PrivateKeyFromSeedAndPath(seed []byte, path string) (*e2types.BLSPrivateKey, error) { - if path == "" { - return nil, errors.New("no path") +func validateRelativePath(relativePath string) bool { + match, _ := regexp.MatchString(`^(\/(\d\d?\d?\d?\d?\d?))+$`, relativePath) + return match +} + +func validateMasterKeyPath(path string) bool { + match, _ := regexp.MatchString(`^m\/`, path) + return match +} + +func PrivateKeyForRelativePath(key []byte, relativePath string) (*e2types.BLSPrivateKey, error) { + if validateMasterKeyPath(relativePath) { + return nil, fmt.Errorf("basePath invalid, not relative. if you need to derive basePath from seed please use PrivateKeyFromSeedAndPath") } - if len(seed) < 16 { - return nil, errors.New("seed must be at least 128 bits") + if !validateRelativePath(relativePath) { + return nil, fmt.Errorf("relative basePath invalid: %s", relativePath) } - pathBits := strings.Split(path, "/") - var sk *big.Int - var err error - for i := range pathBits { + + pathBits := strings.Split(relativePath, "/") + sk := new(big.Int).SetBytes(key) + for i := 1 ; i < len(pathBits) ; i++ { // we skip index 0 as it's empty for relative paths if pathBits[i] == "" { - return nil, fmt.Errorf("no entry at path component %d", i) + return nil, fmt.Errorf("no entry at basePath component %d", i) + } + + index, err := strconv.ParseInt(pathBits[i], 10, 32) + if err != nil || index < 0 { + return nil, fmt.Errorf("invalid index %q at basePath component %d", pathBits[i], i) } - if pathBits[i] == "m" { - if i != 0 { - return nil, fmt.Errorf("invalid master at path component %d", i) - } - sk, err = DeriveMasterSK(seed) - if err != nil { - return nil, errors.Wrapf(err, "failed to generate master key at path component %d", i) - } - } else { - if i == 0 { - return nil, fmt.Errorf("not master at path component %d", i) - } - index, err := strconv.ParseInt(pathBits[i], 10, 32) - if err != nil || index < 0 { - return nil, fmt.Errorf("invalid index %q at path component %d", pathBits[i], i) - } - sk, err = DeriveChildSK(sk, uint32(index)) - if err != nil { - return nil, errors.Wrapf(err, "failed to derive child SK at path component %d", i) - } + sk, err = DeriveChildSK(sk, uint32(index)) + if err != nil { + return nil, errors.Wrapf(err, "failed to derive child SK at basePath component %d", i) } } @@ -84,6 +81,30 @@ func PrivateKeyFromSeedAndPath(seed []byte, path string) (*e2types.BLSPrivateKey return e2types.BLSPrivateKeyFromBytes(bytes) } +// PrivateKeyFromSeedAndPath generates a private key given a seed and a basePath. +// Follows ERC-2334. +func PrivateKeyFromSeedAndPath(seed []byte, path string) (*e2types.BLSPrivateKey, error) { + if len(seed) < 16 { + return nil, errors.New("seed must be at least 128 bits") + } + if !validateMasterKeyPath(path) { + return nil,fmt.Errorf("invalid basePath, should start with m/") + } + + // derive master key + sk, err := DeriveMasterSK(seed) + if err != nil { + return nil, errors.Wrapf(err, "failed to generate master key") + } + + if len(path) > 2 { // try to derive child keys + relativePath := strings.Replace(path,"m","",1) + return PrivateKeyForRelativePath(sk.Bytes(),relativePath) + } else { // derive just seed to master + return e2types.BLSPrivateKeyFromBytes(sk.Bytes()) + } +} + // DeriveMasterSK derives the master secret key from a seed. // Follows ERC-2333. func DeriveMasterSK(seed []byte) (*big.Int, error) { diff --git a/crypto_test.go b/crypto_test.go index 2060aaa..e494a83 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -43,6 +43,106 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func TestPrivateKeyForRelativePath(t *testing.T) { + tests := []struct { + name string + seed []byte + basePath string + relativePath string + err error + sk *big.Int + }{ + { + name: "Nil", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + err: errors.New("invalid basePath, should start with m/"), + }, + { + name: "EmptyPath", + basePath: "", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + err: errors.New("invalid basePath, should start with m/"), + }, + { + name: "EmptySeed", + basePath: "m/", + relativePath: "/12381/3600/0/0", + err: errors.New("seed must be at least 128 bits"), + }, + { + name: "BadPath1", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "m/bad basePath", + err: errors.New(`relative basePath invalid: /bad basePath`), + }, + { + name: "BadPath2", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "m/m/12381", + err: errors.New(`relative basePath invalid: /m/12381`), + }, + { + name: "BadPath3", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "1/m/12381", + err: errors.New(`invalid basePath, should start with m/`), + }, + { + name: "BadPath4", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "m/12381//0", + err: errors.New(`relative basePath invalid: /12381//0`), + }, + { + name: "BadPath5", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "m/12381/-1/0", + err: errors.New(`relative basePath invalid: /12381/-1/0`), + }, + { + name: "Good1", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "m/", + relativePath: "/12381/3600/0/0", + sk: _bigInt("31676788419929922777864946442677915531199062343799598297489487887255736884383"), + }, + { + name: "good2", + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + basePath: "m/12381", + relativePath: "/3600/0/0", + sk: _bigInt("31676788419929922777864946442677915531199062343799598297489487887255736884383"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // derive master key + sk, err := util.PrivateKeyFromSeedAndPath(test.seed, test.basePath) + if test.err != nil { + require.NotNil(t, err) + assert.Equal(t, test.err.Error(), err.Error()) + } else { + require.Nil(t, err) + } + + if err != nil { + return + } + + // derive relative path + sk, err = util.PrivateKeyForRelativePath(sk.Marshal(),test.relativePath) + if test.err != nil { + require.NotNil(t, err) + assert.Equal(t, test.err.Error(), err.Error()) + } else { + require.Nil(t, err) + assert.Equal(t, test.sk.Bytes(), sk.Marshal()) + } + }) + } +} + func TestPrivateKeyFromSeedAndPath(t *testing.T) { tests := []struct { name string @@ -53,12 +153,14 @@ func TestPrivateKeyFromSeedAndPath(t *testing.T) { }{ { name: "Nil", - err: errors.New("no path"), + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + err: errors.New("invalid basePath, should start with m/"), }, { name: "EmptyPath", path: "", - err: errors.New("no path"), + seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + err: errors.New("invalid basePath, should start with m/"), }, { name: "EmptySeed", @@ -68,32 +170,32 @@ func TestPrivateKeyFromSeedAndPath(t *testing.T) { { name: "BadPath1", seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), - path: "m/bad path", - err: errors.New(`invalid index "bad path" at path component 1`), + path: "m/bad basePath", + err: errors.New(`relative basePath invalid: /bad basePath`), }, { name: "BadPath2", seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), path: "m/m/12381", - err: errors.New(`invalid master at path component 1`), + err: errors.New(`relative basePath invalid: /m/12381`), }, { name: "BadPath3", seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), path: "1/m/12381", - err: errors.New(`not master at path component 0`), + err: errors.New(`invalid basePath, should start with m/`), }, { name: "BadPath4", seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), path: "m/12381//0", - err: errors.New(`no entry at path component 2`), + err: errors.New(`relative basePath invalid: /12381//0`), }, { name: "BadPath5", seed: _byteArray("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), path: "m/12381/-1/0", - err: errors.New(`invalid index "-1" at path component 2`), + err: errors.New(`relative basePath invalid: /12381/-1/0`), }, { name: "Good1",