Skip to content

Commit 26e1ce2

Browse files
authored
Merge pull request #702 from smallstep/herman/awskms-decrypter
Add `CreateDecrypter` support to AWS KMS
2 parents 43a6d36 + 3018ea0 commit 26e1ce2

File tree

4 files changed

+377
-1
lines changed

4 files changed

+377
-1
lines changed

kms/awskms/awskms.go

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type KeyManagementClient interface {
3434
CreateKey(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error)
3535
CreateAlias(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error)
3636
Sign(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error)
37+
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
3738
}
3839

3940
// customerMasterKeySpecMapping is a mapping between the step signature algorithm,

kms/awskms/decrypter.go

+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package awskms
2+
3+
import (
4+
"crypto"
5+
"crypto/rsa"
6+
"errors"
7+
"fmt"
8+
"io"
9+
10+
"github.com/aws/aws-sdk-go-v2/service/kms"
11+
"github.com/aws/aws-sdk-go-v2/service/kms/types"
12+
13+
"go.step.sm/crypto/kms/apiv1"
14+
"go.step.sm/crypto/pemutil"
15+
)
16+
17+
// CreateDecrypter implements the [apiv1.Decrypter] interface and returns
18+
// a [crypto.Decrypter] backed by a decryption key in AWS KMS.
19+
func (k *KMS) CreateDecrypter(req *apiv1.CreateDecrypterRequest) (crypto.Decrypter, error) {
20+
if req.DecryptionKey == "" {
21+
return nil, errors.New("decryption key cannot be empty")
22+
}
23+
24+
return NewDecrypter(k.client, req.DecryptionKey)
25+
}
26+
27+
// Decrypter implements a [crypto.Decrypter] using AWS KMS.
28+
type Decrypter struct {
29+
client KeyManagementClient
30+
keyID string
31+
publicKey crypto.PublicKey
32+
}
33+
34+
// NewDecrypter creates a new [crypto.Decrypter] backed by the given
35+
// AWS KMS. decryption key.
36+
func NewDecrypter(client KeyManagementClient, decryptionKey string) (*Decrypter, error) {
37+
keyID, err := parseKeyID(decryptionKey)
38+
if err != nil {
39+
return nil, err
40+
}
41+
42+
decrypter := &Decrypter{
43+
client: client,
44+
keyID: keyID,
45+
}
46+
if err := decrypter.preloadKey(); err != nil {
47+
return nil, err
48+
}
49+
50+
return decrypter, nil
51+
}
52+
53+
func (d *Decrypter) preloadKey() error {
54+
ctx, cancel := defaultContext()
55+
defer cancel()
56+
57+
resp, err := d.client.GetPublicKey(ctx, &kms.GetPublicKeyInput{
58+
KeyId: pointer(d.keyID),
59+
})
60+
if err != nil {
61+
return fmt.Errorf("awskms GetPublicKey failed: %w", err)
62+
}
63+
64+
d.publicKey, err = pemutil.ParseDER(resp.PublicKey)
65+
return err
66+
}
67+
68+
// Public returns the public key of this decrypter
69+
func (d *Decrypter) Public() crypto.PublicKey {
70+
return d.publicKey
71+
}
72+
73+
// Decrypt decrypts ciphertext using the decryption key backed by AWS KMS and returns
74+
// the plaintext bytes. An error is returned when decryption fails. AWS KMS only supports
75+
// RSA keys with 2048, 3072 or 4096 bits and will always use OAEP. It supports SHA1 and SHA256.
76+
// Labels are not supported. Before calling out to AWS, some validation is performed
77+
// so that known bad parameters are detected client-side and a more meaningful error is returned
78+
// for those cases.
79+
//
80+
// Also see https://docs.aws.amazon.com/kms/latest/developerguide/symm-asymm-choose-key-spec.html#key-spec-rsa.
81+
func (d *Decrypter) Decrypt(_ io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) ([]byte, error) {
82+
algorithm, err := determineDecryptionAlgorithm(d.publicKey, opts)
83+
if err != nil {
84+
return nil, fmt.Errorf("failed determining decryption algorithm: %w", err)
85+
}
86+
87+
req := &kms.DecryptInput{
88+
KeyId: pointer(d.keyID),
89+
CiphertextBlob: ciphertext,
90+
EncryptionAlgorithm: algorithm,
91+
}
92+
93+
ctx, cancel := defaultContext()
94+
defer cancel()
95+
96+
response, err := d.client.Decrypt(ctx, req)
97+
if err != nil {
98+
return nil, fmt.Errorf("awskms Decrypt failed: %w", err)
99+
}
100+
101+
return response.Plaintext, nil
102+
}
103+
104+
const (
105+
awsOaepSha1 = "RSAES_OAEP_SHA_1"
106+
awsOaepSha256 = "RSAES_OAEP_SHA_256"
107+
)
108+
109+
func determineDecryptionAlgorithm(key crypto.PublicKey, opts crypto.DecrypterOpts) (types.EncryptionAlgorithmSpec, error) {
110+
pub, ok := key.(*rsa.PublicKey)
111+
if !ok {
112+
return "", fmt.Errorf("awskms does not support key type %T", key)
113+
}
114+
115+
if opts == nil {
116+
opts = &rsa.OAEPOptions{}
117+
}
118+
119+
var rsaOpts *rsa.OAEPOptions
120+
switch o := opts.(type) {
121+
case *rsa.OAEPOptions:
122+
if err := validateOAEPOptions(o); err != nil {
123+
return "", err
124+
}
125+
rsaOpts = o
126+
case *rsa.PKCS1v15DecryptOptions:
127+
return "", errors.New("awskms does not support PKCS #1 v1.5 decryption")
128+
default:
129+
return "", fmt.Errorf("invalid decrypter options type %T", opts)
130+
}
131+
132+
switch bitSize := pub.Size() * 8; bitSize {
133+
default:
134+
return "", fmt.Errorf("awskms does not support RSA public key size %d", bitSize)
135+
case 2048, 3072, 4096:
136+
switch rsaOpts.Hash {
137+
case crypto.SHA1:
138+
return awsOaepSha1, nil
139+
case crypto.SHA256:
140+
return awsOaepSha256, nil
141+
case crypto.Hash(0):
142+
// set a sane default hashing algorithm when it's not set. AWS KMS only supports
143+
// SHA1 and SHA256, so using SHA256 generally shouldn't result in a decryption
144+
// operation breaking, but it depends on the sending side whether or not this
145+
// is the correct value. If it's not provided through opts, then there's no other
146+
// way to determine which algorithm to use, though, so this is an optimistic attempt
147+
// at decryption.
148+
return awsOaepSha256, nil
149+
default:
150+
return "", fmt.Errorf("awskms does not support hash algorithm %q with RSA-OAEP", rsaOpts.Hash)
151+
}
152+
}
153+
}
154+
155+
// validateOAEPOptions validates the RSA OAEP options provided.
156+
func validateOAEPOptions(o *rsa.OAEPOptions) error {
157+
if len(o.Label) > 0 {
158+
return errors.New("awskms does not support RSA-OAEP label")
159+
}
160+
161+
switch {
162+
case o.Hash != 0 && o.MGFHash == 0: // assumes same hash is being used for both
163+
break
164+
case o.Hash != 0 && o.MGFHash != 0 && o.Hash != o.MGFHash:
165+
return fmt.Errorf("awskms does not support using different algorithms for hashing %q and masking %q", o.Hash, o.MGFHash)
166+
}
167+
168+
return nil
169+
}
170+
171+
var _ apiv1.Decrypter = (*KMS)(nil)

kms/awskms/decrypter_test.go

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package awskms
2+
3+
import (
4+
"context"
5+
"crypto"
6+
"crypto/rand"
7+
"crypto/rsa"
8+
"crypto/sha1"
9+
"crypto/sha256"
10+
"encoding/pem"
11+
"fmt"
12+
"hash"
13+
"testing"
14+
15+
"github.com/aws/aws-sdk-go-v2/service/kms"
16+
"github.com/stretchr/testify/require"
17+
18+
"go.step.sm/crypto/kms/apiv1"
19+
"go.step.sm/crypto/pemutil"
20+
)
21+
22+
func TestCreateDecrypter(t *testing.T) {
23+
key, err := pemutil.ParseKey([]byte(rsaPublicKey))
24+
require.NoError(t, err)
25+
require.IsType(t, &rsa.PublicKey{}, key)
26+
rsaKey := key.(*rsa.PublicKey)
27+
28+
k := &KMS{client: &MockClient{
29+
getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) {
30+
block, _ := pem.Decode([]byte(rsaPublicKey))
31+
return &kms.GetPublicKeyOutput{
32+
KeyId: input.KeyId,
33+
PublicKey: block.Bytes,
34+
}, nil
35+
},
36+
}}
37+
38+
// fail with empty decryption key
39+
d, err := k.CreateDecrypter(&apiv1.CreateDecrypterRequest{
40+
DecryptionKey: "",
41+
})
42+
require.Error(t, err)
43+
require.Nil(t, d)
44+
45+
// expect same public key to be returned
46+
d, err = k.CreateDecrypter(&apiv1.CreateDecrypterRequest{
47+
DecryptionKey: "test",
48+
})
49+
require.NoError(t, err)
50+
require.NotNil(t, d)
51+
require.True(t, rsaKey.Equal(d.Public()))
52+
}
53+
54+
func TestDecrypterDecrypts(t *testing.T) {
55+
kms, pub := createTestKMS(t, 2048)
56+
fail1024KMS, _ := createTestKMS(t, 1024)
57+
58+
// prepare encrypted contents
59+
encSHA256, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, pub, []byte("test"), nil)
60+
require.NoError(t, err)
61+
encSHA1, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, pub, []byte("test"), nil)
62+
require.NoError(t, err)
63+
64+
// create a decrypter, identified by "test-sha256", and check the public key
65+
d256, err := kms.CreateDecrypter(&apiv1.CreateDecrypterRequest{
66+
DecryptionKey: "test-sha256",
67+
})
68+
require.NoError(t, err)
69+
require.NotNil(t, d256)
70+
require.True(t, pub.Equal(d256.Public()))
71+
72+
// create a decrypter, identified by "test-sha1", and check the public key
73+
d1, err := kms.CreateDecrypter(&apiv1.CreateDecrypterRequest{
74+
DecryptionKey: "test-sha1",
75+
})
76+
require.NoError(t, err)
77+
require.NotNil(t, d1)
78+
require.True(t, pub.Equal(d1.Public()))
79+
80+
t.Run("ok/sha256", func(t *testing.T) {
81+
// successful decryption using OAEP with SHA-256
82+
plain, err := d256.Decrypt(nil, encSHA256, &rsa.OAEPOptions{Hash: crypto.SHA256})
83+
require.NoError(t, err)
84+
require.Equal(t, []byte("test"), plain)
85+
})
86+
87+
t.Run("ok/sha1", func(t *testing.T) {
88+
// successful decryption using OAEP with SHA-1
89+
plain, err := d1.Decrypt(nil, encSHA1, &rsa.OAEPOptions{Hash: crypto.SHA1})
90+
require.NoError(t, err)
91+
require.Equal(t, []byte("test"), plain)
92+
})
93+
94+
t.Run("ok/default-options", func(t *testing.T) {
95+
// successful decryption, defaulting to OAEP with SHA-256
96+
plain, err := d256.Decrypt(nil, encSHA256, nil)
97+
require.NoError(t, err)
98+
require.Equal(t, []byte("test"), plain)
99+
})
100+
101+
t.Run("fail/hash", func(t *testing.T) {
102+
plain, err := d256.Decrypt(nil, encSHA256, &rsa.OAEPOptions{Hash: crypto.SHA384})
103+
require.EqualError(t, err, `failed determining decryption algorithm: awskms does not support hash algorithm "SHA-384" with RSA-OAEP`)
104+
require.Empty(t, plain)
105+
})
106+
107+
t.Run("fail/label", func(t *testing.T) {
108+
plain, err := d256.Decrypt(nil, encSHA256, &rsa.OAEPOptions{Hash: crypto.SHA256, Label: []byte{1, 2, 3, 4}})
109+
require.EqualError(t, err, "failed determining decryption algorithm: awskms does not support RSA-OAEP label")
110+
require.Empty(t, plain)
111+
})
112+
113+
t.Run("fail/hash-mismatch", func(t *testing.T) {
114+
plain, err := d256.Decrypt(nil, encSHA256, &rsa.OAEPOptions{Hash: crypto.SHA256, MGFHash: crypto.SHA384})
115+
require.EqualError(t, err, `failed determining decryption algorithm: awskms does not support using different algorithms for hashing "SHA-256" and masking "SHA-384"`)
116+
require.Empty(t, plain)
117+
})
118+
119+
t.Run("fail/pkcs15", func(t *testing.T) {
120+
plain, err := d256.Decrypt(nil, encSHA256, &rsa.PKCS1v15DecryptOptions{})
121+
require.EqualError(t, err, "failed determining decryption algorithm: awskms does not support PKCS #1 v1.5 decryption")
122+
require.Empty(t, plain)
123+
})
124+
125+
t.Run("fail/invalid-options", func(t *testing.T) {
126+
plain, err := d256.Decrypt(nil, encSHA256, struct{}{})
127+
require.EqualError(t, err, "failed determining decryption algorithm: invalid decrypter options type struct {}")
128+
require.Empty(t, plain)
129+
})
130+
131+
t.Run("fail/invalid-key", func(t *testing.T) {
132+
failingDecrypter, err := fail1024KMS.CreateDecrypter(&apiv1.CreateDecrypterRequest{
133+
DecryptionKey: "fail",
134+
})
135+
require.NoError(t, err)
136+
137+
_, err = failingDecrypter.Decrypt(nil, nil, nil)
138+
require.EqualError(t, err, "failed determining decryption algorithm: awskms does not support RSA public key size 1024")
139+
})
140+
}
141+
142+
func createTestKMS(t *testing.T, bitSize int) (*KMS, *rsa.PublicKey) {
143+
t.Helper()
144+
145+
key, err := rsa.GenerateKey(rand.Reader, bitSize)
146+
require.NoError(t, err)
147+
148+
k := &KMS{client: &MockClient{
149+
getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) {
150+
block, _ := pemutil.Serialize(key.Public())
151+
return &kms.GetPublicKeyOutput{
152+
KeyId: input.KeyId,
153+
PublicKey: block.Bytes,
154+
}, nil
155+
},
156+
decrypt: func(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error) {
157+
var h hash.Hash
158+
switch *params.KeyId {
159+
case "test-sha256":
160+
if params.EncryptionAlgorithm != "RSAES_OAEP_SHA_256" {
161+
return nil, fmt.Errorf("invalid encryption algorithm %q", params.EncryptionAlgorithm)
162+
}
163+
h = sha256.New()
164+
case "test-sha1":
165+
if params.EncryptionAlgorithm != "RSAES_OAEP_SHA_1" {
166+
return nil, fmt.Errorf("invalid encryption algorithm %q", params.EncryptionAlgorithm)
167+
}
168+
h = sha1.New()
169+
default:
170+
return nil, fmt.Errorf("invalid key ID %q", *params.KeyId)
171+
}
172+
173+
dec, err := rsa.DecryptOAEP(h, nil, key, params.CiphertextBlob, nil)
174+
if err != nil {
175+
return nil, err
176+
}
177+
return &kms.DecryptOutput{
178+
KeyId: params.KeyId,
179+
Plaintext: dec,
180+
}, nil
181+
},
182+
}}
183+
184+
return k, &key.PublicKey
185+
}

0 commit comments

Comments
 (0)