From 66a3d8afd1d5faa1109b9a048d8f728da3213b43 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 1 Nov 2023 00:22:43 +0100 Subject: [PATCH 01/18] Allow VAPID sub to be an HTTPS URL The RFC says: > The "sub" claim SHOULD include a contact URI for the application > server as either a "mailto:" (email) [RFC6068] or an "https:" > [RFC2818] URI. However the library assumes the passed in subscribed is always an e-mail address, without leaving a way to pass an HTTPS URL. --- vapid.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vapid.go b/vapid.go index fe2c580..58e6173 100644 --- a/vapid.go +++ b/vapid.go @@ -8,6 +8,7 @@ import ( "fmt" "math/big" "net/url" + "strings" "time" "github.com/golang-jwt/jwt" @@ -72,10 +73,15 @@ func getVAPIDAuthorizationHeader( return "", err } + // Unless subscriber is an HTTPS URL, assume an e-mail address + if !strings.HasPrefix(subscriber, "https:") { + subscriber = fmt.Sprintf("mailto:%s", subscriber) + } + token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ "aud": fmt.Sprintf("%s://%s", subURL.Scheme, subURL.Host), "exp": expiration.Unix(), - "sub": fmt.Sprintf("mailto:%s", subscriber), + "sub": subscriber, }) // Decode the VAPID private key From dcf6893e6518e4ce7960d702862305745eff17e5 Mon Sep 17 00:00:00 2001 From: Willi Schinmeyer Date: Mon, 6 Nov 2023 16:01:53 +0100 Subject: [PATCH 02/18] Add end to end test This new test mocks both the user agent (e.g. browser) and the push service (e.g. Firestore) to verify that encryption and decryption works properly. I used the RFCs as reference (RFC8291, RFC8292 & RFC 8188), but didn't follow them to the letter. The result can successfully check all the signatures and decrypt the content, so it seems to be working. Instead of the deprecated crypto/elliptic functions, this makes heavy use of crypto/ecdh, which require Go 1.20. But this is only a test dependency, library users should not be impacted. --- end2end_test.go | 462 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 462 insertions(+) create mode 100644 end2end_test.go diff --git a/end2end_test.go b/end2end_test.go new file mode 100644 index 0000000..f7858f4 --- /dev/null +++ b/end2end_test.go @@ -0,0 +1,462 @@ +package webpush + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/golang-jwt/jwt" + "golang.org/x/crypto/hkdf" +) + +func TestEnd2End(t *testing.T) { + var ( + // the data known to the application server (backend, which uses webpush-go) + applicationServer struct { + publicVAPIDKey string + privateVAPIDKey string + subscription Subscription + } + // the data known to the user agent (browser) + userAgent struct { + publicVAPIDKey *ecdsa.PublicKey + subscriptionKey *ecdsa.PrivateKey + authSecret [16]byte + subscription Subscription + receivedNotifications [][]byte + } + // the data known to the push server (which receives push messages on behalf of the user agent, e.g. Firestore) + pushService struct { + applicationServerKey *ecdsa.PublicKey + receivedNotifications [][]byte + } + + err error + ) + + // a VAPID key pair for the application server, usually only generated once and reused + applicationServer.privateVAPIDKey, applicationServer.publicVAPIDKey, err = GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // The application server needs to inform the user agent of the public VAPID key. + // (We decode it first for ease of use.) + userAgent.publicVAPIDKey, err = decodeVAPIDPublicKey(applicationServer.publicVAPIDKey) + if err != nil { + t.Fatal(err) + } + + // We need a mock push service for webpush-go to send notifications to. + var mockPushService *httptest.Server + mockPushService = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // check that there's a valid vapid JWT + token, err := parseVapidAuthHeader( + r.Header.Get("Authorization"), + // by the time this function is called, this value will be set (see PushManager.subscribe() below) + pushService.applicationServerKey) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "invalid auth: %s", err) + return + } + // verify that the audience matches our URL + aud := token.Claims.(jwt.MapClaims)["aud"] + if aud != mockPushService.URL { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "JWT has bad audience, want %q, got %q", mockPushService.URL, aud) + return + } + // RFC8188 only allows for exactly one content encoding + if contentEncoding := r.Header.Get("Content-Encoding"); contentEncoding != "aes128gcm" { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "unsupported Content-Encoding, want %q, got %q", "aes128gcm", contentEncoding) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + // this suggests a broken connection, so log the error instead of sending it back + t.Errorf("failed to read request body: %s", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + // store body for later decoding by user agent + // (the push service doesn't have the key required for decryption) + pushService.receivedNotifications = append(pushService.receivedNotifications, body) + + w.WriteHeader(http.StatusAccepted) + })) + defer mockPushService.Close() + + // what follows is the equivalent of PushManager.subscribe() in JS + { + // the user agent generates its own key pair so it can be sent encrypted messages + userAgent.subscriptionKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generating user agent keys: %s", err) + } + // we need the ECDH representation + ecdhPublicKey, err := userAgent.subscriptionKey.PublicKey.ECDH() + if err != nil { + t.Fatalf("converting user agent public key to ECDH: %s", err) + } + // generate the shared auth secret + _, err = rand.Read(userAgent.authSecret[:]) + if err != nil { + t.Fatalf("generating user agent auth secret: %s", err) + } + // the user agent then performs a registration with the push service using that key, + // while also letting the push service know the application server key to expect. + pushService.applicationServerKey = userAgent.publicVAPIDKey + userAgent.subscription = Subscription{ + Keys: Keys{ + Auth: base64.StdEncoding.EncodeToString(userAgent.authSecret[:]), + P256dh: base64.StdEncoding.EncodeToString(ecdhPublicKey.Bytes()), + }, + Endpoint: mockPushService.URL, + } + } + + // the user agent sends its subscription to the application server... + applicationServer.subscription = userAgent.subscription + + // ...and the application server uses the subscription to send a push notification + sentMessage := "this is our test push notification" + resp, err := SendNotification([]byte(sentMessage), &applicationServer.subscription, &Options{ + HTTPClient: mockPushService.Client(), + VAPIDPublicKey: applicationServer.publicVAPIDKey, + VAPIDPrivateKey: applicationServer.privateVAPIDKey, + Subscriber: "test@example.com", + }) + if err != nil { + t.Fatalf("failed to send notification: %s", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("error closing mock push service response body: %s", err) + } + }() + // check for success + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading mock push service response body: %s", err) + } + if resp.StatusCode/100 != 2 { + t.Errorf("unexpected push service status code %d, body: %s", resp.StatusCode, respBody) + } + + // the push server should now have received the notification + if l := len(pushService.receivedNotifications); l != 1 { + t.Fatalf("Want 1 notification received by push service, got %d", l) + } + // the push service then forwards the notification to the user agent + userAgent.receivedNotifications = pushService.receivedNotifications + // and the user agent can decrypt them + receivedMessage, err := decodeNotification(userAgent.receivedNotifications[0], userAgent.authSecret, userAgent.subscriptionKey) + if err != nil { + t.Fatalf("error decrypting notification in user agent: %s", err) + } + if receivedMessage != sentMessage { + t.Errorf("Sent notification %q, but got %q", sentMessage, receivedMessage) + } +} + +func decodeVAPIDPublicKey(publicVAPIDKey string) (*ecdsa.PublicKey, error) { + publicVAPIDKeyBytes, err := base64.RawURLEncoding.DecodeString(publicVAPIDKey) + if err != nil { + return nil, fmt.Errorf("base64-decoding public VAPID key: %w", err) + } + return decodeECDSAPublicKey(publicVAPIDKeyBytes) +} + +func decodeECDSAPublicKey(bytes []byte) (*ecdsa.PublicKey, error) { + ecdhKey, err := ecdh.P256().NewPublicKey(bytes) + if err != nil { + return nil, fmt.Errorf("parsing public VAPID key: %w", err) + } + res, err := ecdhPublicKeyToECDSA(ecdhKey) + if err != nil { + return nil, fmt.Errorf("converting public VAPID key from *ecdh.PublicKey to *ecdsa.PublicKey: %w", err) + } + return res, nil +} + +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + // see https://github.com/golang/go/issues/63963 + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} + +func Test_ecdhPublicKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + original := &pk.PublicKey + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) + } + roundtrip, err := ecdhPublicKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func parseVapidAuthHeader(authHeader string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + if authHeader == "" { + return nil, fmt.Errorf("missing auth header") + } + // the Authorization header should be of the form "vapid t=JWT, k=key" (RFC8292) + // we need to extract the JWT (JSON Web Token) from t to check the signature using k + authBody, found := strings.CutPrefix(authHeader, "vapid ") + if !found { + return nil, fmt.Errorf("Authorization header is not vapid: %s", authHeader) + } + authFields := strings.Split(authBody, ",") + rawJWT := "" + rawKey := "" + for _, field := range authFields { + kv := strings.SplitN(field, "=", 2) + if len(kv) < 2 { + return nil, fmt.Errorf("push service vapid Authorization header field %q malformed", field) + } + key := strings.TrimSpace(kv[0]) + val := strings.TrimSpace(kv[1]) + switch key { + case "t": + rawJWT = val + case "k": + rawKey = val + default: + // other fields irrelevant to us + } + } + if rawJWT == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"t\" field (JWT)") + } + if rawKey == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"k\" field") + } + key, err := decodeVAPIDPublicKey(rawKey) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization key: %w", err) + } + // check that the key matches the known applicationServerKey + // (RFC8292 4.2) + if !key.Equal(applicationServerKey) { + // in real code, this would mean the user agent needs to resubscribe with the new applicationServerKey + return nil, fmt.Errorf("vapid Authorization key does not match applicationServerKey from subscription") + } + + // verify the JWT signature + token, err := parseJWT(rawJWT, key) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization JWT: %w", err) + } + return token, nil +} + +func parseJWT(rawJWT string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + token, err := jwt.Parse(rawJWT, func(t *jwt.Token) (interface{}, error) { + switch t.Method.Alg() { + case "ES256": + return applicationServerKey, nil + default: + return nil, fmt.Errorf("unsupported JWT signing alg %q", t.Method.Alg()) + } + }) + if err != nil { + return nil, fmt.Errorf("decoding JWT %s: %w", rawJWT, err) + } + return token, nil +} + +func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.PrivateKey) (string, error) { + // remember initial body length, before we start consuming it + bodyLen := len(body) + // the body is aes128gcm-encoded as described in RFC8188, + // starting with this header: + // +-----------+--------+-----------+---------------+ + // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | + // +-----------+--------+-----------+---------------+ + salt, body := body[:16], body[16:] + recordSize, body := int(binary.BigEndian.Uint32(body[:4])), body[4:] + idLen, body := int(uint8(body[0])), body[1:] + rawPubKey, body := body[:idLen], body[idLen:] + if bodyLen != recordSize { + // this could mean a multi-record message was sent, this simplified parser does not support those. + return "", fmt.Errorf("expected body length %d, got %d", recordSize, bodyLen) + } + + // parse keys and derive shared secret + pubKey, err := decodeECDSAPublicKey(rawPubKey) + if err != nil { + return "", fmt.Errorf("decoding public key from header: %w", err) + } + pubKeyECDH, err := pubKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting public key to ECDH: %w", err) + } + userAgentECDHKey, err := userAgentKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting user agent private key to ECDH: %w", err) + } + userAgentECDHPublicKey, err := userAgentKey.PublicKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting user agent public key to ECDH: %w", err) + } + + sharedECDHSecret, err := userAgentECDHKey.ECDH(pubKeyECDH) + if err != nil { + return "", fmt.Errorf("deriving shared secret from notification public key and user agent private key: %w", err) + } + + hash := sha256.New + + // ikm + prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) + prkInfoBuf.Write(userAgentECDHPublicKey.Bytes()) // aka "dh" + prkInfoBuf.Write(pubKeyECDH.Bytes()) + + prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret[:], prkInfoBuf.Bytes()) + ikm, err := getHKDFKey(prkHKDF, 32) + if err != nil { + return "", fmt.Errorf("deriving ikm: %w", err) + } + + // Derive Content Encryption Key + contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") + contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo) + contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) + if err != nil { + return "", fmt.Errorf("deriving content encryption key: %w", err) + } + + // Derive the Nonce + nonceInfo := []byte("Content-Encoding: nonce\x00") + nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo) + nonce, err := getHKDFKey(nonceHKDF, 12) + if err != nil { + return "", fmt.Errorf("deriving nonce: %w", err) + } + + // Cipher + c, err := aes.NewCipher(contentEncryptionKey) + if err != nil { + return "", fmt.Errorf("creating cipher block: %w", err) + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return "", fmt.Errorf("creating GCM: %w", err) + } + + // Decrypt + res, err := gcm.Open(nil, nonce, body, nil) + if err != nil { + return "", fmt.Errorf("decrypting: %w", err) + } + + // the message is padded with 0x02 0x00 0x00 0x00 [...] 0x00, we need to remove that + lastNull := len(res) + for ; lastNull > 0 && res[lastNull-1] == 0x00; lastNull-- { + } + if lastNull == 0 { + // we expect at least one 0x02 (or 0x01) before the nulls, not finding one is wrong + return "", fmt.Errorf("decryption yielded only %d null bytes", len(res)) + } + if beforeNull := res[lastNull-1]; beforeNull != 0x02 { + // if we get an 0x01, it means we have a multi-record message, this mock does not implement those + return "", fmt.Errorf("padding nulls in decrypted message should be preceded by 0x02 delimiter, got %02X", beforeNull) + } + // strip trailing nulls and separating 0x02 + res = res[:lastNull-1] + + return string(res), nil +} + +// test for the decoding helper function +func Test_decodeVAPIDPublicKey(t *testing.T) { + privKeyB64, pubKeyB64, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // as a baseline, decode using the library functions + privKeyBytes, err := decodeVapidKey(privKeyB64) + if err != nil { + t.Fatalf("decoding private key: %s", err) + } + privKey := generateVAPIDHeaderKeys(privKeyBytes) + wantPubKey := &privKey.PublicKey + + // now decode using our test helper and compare the results + gotPubKey, err := decodeVAPIDPublicKey(pubKeyB64) + if err != nil { + t.Fatalf("decoding public key") + } + if !gotPubKey.Equal(wantPubKey) { + t.Errorf("result differs:\ngot: %v\nwant: %v", gotPubKey, wantPubKey) + } +} From 3e0c7552ae08118117f79a6c97af5a50e90c2bd7 Mon Sep 17 00:00:00 2001 From: Willi Schinmeyer Date: Tue, 7 Nov 2023 11:52:40 +0100 Subject: [PATCH 03/18] Replace deprecated crypto/elliptic with crypto/ecdh crypto/elliptic is subject for removal soon, use of crypto/ecdh is advised instead. This has a number of side-effects: - the required Go version increases to Go 1.20 - the configured VAPID keys get verified now, and invalid keys are rejected - this necessitated changes to some tests - VAPID is effectively mandatory now (but all push services I know require it anyway) go.mod has been updated to reflect the new requirement, and I ran `go mod tidy` to clean up go.sum. I also added additional error context by wrapping errors with fmt.Errorf's %w verb. This was introduced in Go 1.13. --- README.md | 2 +- end2end_test.go | 74 +++------------------------------------- go.mod | 2 +- go.sum | 34 ------------------ vapid.go | 81 ++++++++++++++++++++++++++++--------------- vapid_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++-- webpush.go | 41 ++++++++++------------ webpush_test.go | 16 ++++++--- 8 files changed, 179 insertions(+), 162 deletions(-) diff --git a/README.md b/README.md index c313fc6..3a0ab13 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ if err != nil { ## Development -1. Install [Go 1.11+](https://golang.org/) +1. Install [Go 1.20+](https://golang.org/) 2. `go mod vendor` 3. `go test` diff --git a/end2end_test.go b/end2end_test.go index f7858f4..40cbc72 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -13,7 +13,6 @@ import ( "encoding/binary" "fmt" "io" - "math/big" "net/http" "net/http/httptest" "strings" @@ -195,74 +194,6 @@ func decodeECDSAPublicKey(bytes []byte) (*ecdsa.PublicKey, error) { return res, nil } -func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { - // see https://github.com/golang/go/issues/63963 - rawKey := key.Bytes() - switch key.Curve() { - case ecdh.P256(): - return &ecdsa.PublicKey{ - Curve: elliptic.P256(), - X: big.NewInt(0).SetBytes(rawKey[1:33]), - Y: big.NewInt(0).SetBytes(rawKey[33:]), - }, nil - case ecdh.P384(): - return &ecdsa.PublicKey{ - Curve: elliptic.P384(), - X: big.NewInt(0).SetBytes(rawKey[1:49]), - Y: big.NewInt(0).SetBytes(rawKey[49:]), - }, nil - case ecdh.P521(): - return &ecdsa.PublicKey{ - Curve: elliptic.P521(), - X: big.NewInt(0).SetBytes(rawKey[1:67]), - Y: big.NewInt(0).SetBytes(rawKey[67:]), - }, nil - default: - return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") - } -} - -func Test_ecdhPublicKeyToECDSA(t *testing.T) { - tests := [...]struct { - name string - curve elliptic.Curve - }{ - // P224 not supported by ecdh - { - name: "P256", - curve: elliptic.P256(), - }, - { - name: "P256", - curve: elliptic.P384(), - }, - { - name: "P521", - curve: elliptic.P521(), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) - if err != nil { - t.Fatalf("generating ecdsa.PrivateKey: %s", err) - } - original := &pk.PublicKey - converted, err := original.ECDH() - if err != nil { - t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) - } - roundtrip, err := ecdhPublicKeyToECDSA(converted) - if err != nil { - t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) - } - if !roundtrip.Equal(original) { - t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) - } - }) - } -} - func parseVapidAuthHeader(authHeader string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { if authHeader == "" { return nil, fmt.Errorf("missing auth header") @@ -448,7 +379,10 @@ func Test_decodeVAPIDPublicKey(t *testing.T) { if err != nil { t.Fatalf("decoding private key: %s", err) } - privKey := generateVAPIDHeaderKeys(privKeyBytes) + privKey, err := generateVAPIDHeaderKeys(privKeyBytes) + if err != nil { + t.Fatalf("converting private key: %s", err) + } wantPubKey := &privKey.PublicKey // now decode using our test helper and compare the results diff --git a/go.mod b/go.mod index 6b0604f..642b7dd 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,4 @@ require ( golang.org/x/crypto v0.9.0 ) -go 1.13 +go 1.20 diff --git a/go.sum b/go.sum index d9575c4..a2e2828 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,4 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vapid.go b/vapid.go index fe2c580..9e82e96 100644 --- a/vapid.go +++ b/vapid.go @@ -1,6 +1,7 @@ package webpush import ( + "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -16,46 +17,69 @@ import ( // GenerateVAPIDKeys will create a private and public VAPID key pair func GenerateVAPIDKeys() (privateKey, publicKey string, err error) { // Get the private key from the P256 curve - curve := elliptic.P256() + curve := ecdh.P256() - private, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + private, err := curve.GenerateKey(rand.Reader) if err != nil { return } - public := elliptic.Marshal(curve, x, y) - // Convert to base64 - publicKey = base64.RawURLEncoding.EncodeToString(public) - privateKey = base64.RawURLEncoding.EncodeToString(private) - + publicKey = base64.RawURLEncoding.EncodeToString(private.PublicKey().Bytes()) + privateKey = base64.RawURLEncoding.EncodeToString(private.Bytes()) return } // Generates the ECDSA public and private keys for the JWT encryption -func generateVAPIDHeaderKeys(privateKey []byte) *ecdsa.PrivateKey { - // Public key - curve := elliptic.P256() - px, py := curve.ScalarMult( - curve.Params().Gx, - curve.Params().Gy, - privateKey, - ) - - pubKey := ecdsa.PublicKey{ - Curve: curve, - X: px, - Y: py, +func generateVAPIDHeaderKeys(privateKey []byte) (*ecdsa.PrivateKey, error) { + key, err := ecdh.P256().NewPrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("validating private key: %w", err) } + converted, err := ecdhPrivateKeyToECDSA(key) + if err != nil { + return nil, fmt.Errorf("converting private key to crypto/ecdsa: %w", err) + } + return converted, nil +} - // Private key - d := &big.Int{} - d.SetBytes(privateKey) +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + // see https://github.com/golang/go/issues/63963 + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} - return &ecdsa.PrivateKey{ - PublicKey: pubKey, - D: d, +func ecdhPrivateKeyToECDSA(key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { + // see https://github.com/golang/go/issues/63963 + pubKey, err := ecdhPublicKeyToECDSA(key.PublicKey()) + if err != nil { + return nil, fmt.Errorf("converting PublicKey part of *ecdh.PrivateKey: %w", err) } + return &ecdsa.PrivateKey{ + PublicKey: *pubKey, + D: big.NewInt(0).SetBytes(key.Bytes()), + }, nil } // getVAPIDAuthorizationHeader @@ -84,7 +108,10 @@ func getVAPIDAuthorizationHeader( return "", err } - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + if err != nil { + return "", fmt.Errorf("generating VAPID header keys: %w", err) + } // Sign token with private key jwtString, err := token.SignedString(privKey) diff --git a/vapid_test.go b/vapid_test.go index be4cca8..d77d4ec 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -1,6 +1,9 @@ package webpush import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "encoding/base64" "fmt" "strings" @@ -44,10 +47,13 @@ func TestVAPID(t *testing.T) { b64 := base64.RawURLEncoding decodedVapidPrivateKey, err := b64.DecodeString(vapidPrivateKey) if err != nil { - t.Fatal("Could not decode VAPID private key") + t.Fatalf("Could not decode VAPID private key: %s", err) } - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + if err != nil { + t.Fatalf("Could not parse VAPID private key: %s", err) + } return privKey.Public(), nil }) @@ -100,3 +106,84 @@ func getTokenFromAuthorizationHeader(tokenHeader string, t *testing.T) string { return tsplit[1][:len(tsplit[1])-1] } + +func Test_ecdhPublicKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + original := &pk.PublicKey + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) + } + roundtrip, err := ecdhPublicKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func Test_ecdhPrivateKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PrivateKey to ecdh.PrivateKey: %s", err) + } + roundtrip, err := ecdhPrivateKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PrivateKey back to ecdsa.PrivateKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} diff --git a/webpush.go b/webpush.go index 4c85ad6..a2d9768 100644 --- a/webpush.go +++ b/webpush.go @@ -5,12 +5,13 @@ import ( "context" "crypto/aes" "crypto/cipher" - "crypto/elliptic" + "crypto/ecdh" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/binary" "errors" + "fmt" "io" "net/http" "strconv" @@ -77,53 +78,47 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Authentication secret (auth_secret) authSecret, err := decodeSubscriptionKey(s.Keys.Auth) if err != nil { - return nil, err + return nil, fmt.Errorf("decoding keys.auth: %w", err) } // dh (Diffie Hellman) dh, err := decodeSubscriptionKey(s.Keys.P256dh) if err != nil { - return nil, err + return nil, fmt.Errorf("decoding keys.p256dh: %w", err) + } + userAgentPublicKey, err := ecdh.P256().NewPublicKey(dh) + if err != nil { + return nil, fmt.Errorf("validating keys.p256dh: %w", err) } // Generate 16 byte salt salt, err := saltFunc() if err != nil { - return nil, err + return nil, fmt.Errorf("generating salt: %w", err) } // Create the ecdh_secret shared key pair - curve := elliptic.P256() // Application server key pairs (single use) - localPrivateKey, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { return nil, err } - localPublicKey := elliptic.Marshal(curve, x, y) - - // Combine application keys with receiver's EC public key - sharedX, sharedY := elliptic.Unmarshal(curve, dh) - if sharedX == nil { - return nil, errors.New("Unmarshal Error: Public key is not a valid point on the curve") - } + localPublicKey := localPrivateKey.PublicKey() - // Derive ECDH shared secret - sx, sy := curve.ScalarMult(sharedX, sharedY, localPrivateKey) - if !curve.IsOnCurve(sx, sy) { - return nil, errors.New("Encryption error: ECDH shared secret isn't on curve") + // Combine application keys with receiver's EC public key to derive ECDH shared secret + sharedECDHSecret, err := localPrivateKey.ECDH(userAgentPublicKey) + if err != nil { + return nil, fmt.Errorf("deriving shared secret: %w", err) } - mlen := curve.Params().BitSize / 8 - sharedECDHSecret := make([]byte, mlen) - sx.FillBytes(sharedECDHSecret) hash := sha256.New // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) prkInfoBuf.Write(dh) - prkInfoBuf.Write(localPublicKey) + prkInfoBuf.Write(localPublicKey.Bytes()) prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret, prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) @@ -173,8 +168,8 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri binary.BigEndian.PutUint32(rs, recordSize) recordBuf.Write(rs) - recordBuf.Write([]byte{byte(len(localPublicKey))}) - recordBuf.Write(localPublicKey) + recordBuf.Write([]byte{byte(len(localPublicKey.Bytes()))}) + recordBuf.Write(localPublicKey.Bytes()) // Data dataBuf := bytes.NewBuffer(message) diff --git a/webpush_test.go b/webpush_test.go index 807a1f7..d1f74c5 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -33,6 +33,10 @@ func getStandardEncodedTestSubscription() *Subscription { } func TestSendNotificationToURLEncodedSubscription(t *testing.T) { + priv, pub, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } resp, err := SendNotification([]byte("Test"), getURLEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, RecordSize: 3070, @@ -40,8 +44,8 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { Topic: "test_topic", TTL: 0, Urgency: "low", - VAPIDPublicKey: "test-public", - VAPIDPrivateKey: "test-private", + VAPIDPublicKey: pub, + VAPIDPrivateKey: priv, }) if err != nil { t.Fatal(err) @@ -49,7 +53,7 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { if resp.StatusCode != 201 { t.Fatalf( - "Incorreect status code, expected=%d, got=%d", + "Incorrect status code, expected=%d, got=%d", resp.StatusCode, 201, ) @@ -57,13 +61,17 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { } func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { + priv, _, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } resp, err := SendNotification([]byte("Test"), getStandardEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, Subscriber: "", Topic: "test_topic", TTL: 0, Urgency: "low", - VAPIDPrivateKey: "testKey", + VAPIDPrivateKey: priv, }) if err != nil { t.Fatal(err) From 0e35be50b80c48d97eec83111d441219748ce25c Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Thu, 12 Dec 2024 01:24:51 -0500 Subject: [PATCH 04/18] upgrade jwt to v5 --- go.mod | 2 +- go.sum | 2 ++ vapid.go | 2 +- vapid_test.go | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6b0604f..e534490 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/SherClockHolmes/webpush-go require ( - github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt/v5 v5.2.1 golang.org/x/crypto v0.9.0 ) diff --git a/go.sum b/go.sum index d9575c4..4a47b67 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/vapid.go b/vapid.go index fe2c580..6b5eb3e 100644 --- a/vapid.go +++ b/vapid.go @@ -10,7 +10,7 @@ import ( "net/url" "time" - "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) // GenerateVAPIDKeys will create a private and public VAPID key pair diff --git a/vapid_test.go b/vapid_test.go index be4cca8..3dc0b6c 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) func TestVAPID(t *testing.T) { From e758833f70d1966b9edbcf2934d423112437524c Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Wed, 25 Dec 2024 02:02:48 -0500 Subject: [PATCH 05/18] update end2end_test to use jwt v5 --- end2end_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/end2end_test.go b/end2end_test.go index 40cbc72..a6f5a97 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -18,7 +18,7 @@ import ( "strings" "testing" - "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/hkdf" ) From 21b34d8d84c99d381d114b13266b1befeb00e981 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 06/18] API-breaking changes to key handling * New key types for Keys and VAPIDKeys * VAPID key serializes to a JSON string of the PEM-encoded PKCS8 --- end2end_test.go | 51 ++++++------- legacy.go | 76 +++++++++++++++++++ legacy_test.go | 55 ++++++++++++++ vapid.go | 197 +++++++++++++++++++++++++++--------------------- vapid_test.go | 71 ++++++++++++----- webpush.go | 122 +++++++++++++++++++++--------- webpush_test.go | 77 ++++++++++--------- 7 files changed, 443 insertions(+), 206 deletions(-) create mode 100644 legacy.go create mode 100644 legacy_test.go diff --git a/end2end_test.go b/end2end_test.go index a6f5a97..ff134c5 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -15,6 +15,7 @@ import ( "io" "net/http" "net/http/httptest" + "strconv" "strings" "testing" @@ -26,9 +27,8 @@ func TestEnd2End(t *testing.T) { var ( // the data known to the application server (backend, which uses webpush-go) applicationServer struct { - publicVAPIDKey string - privateVAPIDKey string - subscription Subscription + vapidKeys *VAPIDKeys + subscription Subscription } // the data known to the user agent (browser) userAgent struct { @@ -48,17 +48,14 @@ func TestEnd2End(t *testing.T) { ) // a VAPID key pair for the application server, usually only generated once and reused - applicationServer.privateVAPIDKey, applicationServer.publicVAPIDKey, err = GenerateVAPIDKeys() + applicationServer.vapidKeys, err = GenerateVAPIDKeys() if err != nil { t.Fatalf("generating VAPID keys: %s", err) } // The application server needs to inform the user agent of the public VAPID key. // (We decode it first for ease of use.) - userAgent.publicVAPIDKey, err = decodeVAPIDPublicKey(applicationServer.publicVAPIDKey) - if err != nil { - t.Fatal(err) - } + userAgent.publicVAPIDKey = applicationServer.vapidKeys.privateKey.Public().(*ecdsa.PublicKey) // We need a mock push service for webpush-go to send notifications to. var mockPushService *httptest.Server @@ -93,6 +90,13 @@ func TestEnd2End(t *testing.T) { w.WriteHeader(http.StatusInternalServerError) return } + contentLength, err := strconv.Atoi(r.Header.Get("Content-Length")) + if err != nil { + t.Errorf("invalid content-length `%s`: %v", r.Header.Get("Content-Length"), err) + } + if len(body) != contentLength { + t.Errorf("body length %d did not match content-length header %d", len(body), contentLength) + } // store body for later decoding by user agent // (the push service doesn't have the key required for decryption) pushService.receivedNotifications = append(pushService.receivedNotifications, body) @@ -123,8 +127,8 @@ func TestEnd2End(t *testing.T) { pushService.applicationServerKey = userAgent.publicVAPIDKey userAgent.subscription = Subscription{ Keys: Keys{ - Auth: base64.StdEncoding.EncodeToString(userAgent.authSecret[:]), - P256dh: base64.StdEncoding.EncodeToString(ecdhPublicKey.Bytes()), + Auth: userAgent.authSecret, + P256dh: ecdhPublicKey, }, Endpoint: mockPushService.URL, } @@ -136,10 +140,9 @@ func TestEnd2End(t *testing.T) { // ...and the application server uses the subscription to send a push notification sentMessage := "this is our test push notification" resp, err := SendNotification([]byte(sentMessage), &applicationServer.subscription, &Options{ - HTTPClient: mockPushService.Client(), - VAPIDPublicKey: applicationServer.publicVAPIDKey, - VAPIDPrivateKey: applicationServer.privateVAPIDKey, - Subscriber: "test@example.com", + HTTPClient: mockPushService.Client(), + VAPIDKeys: applicationServer.vapidKeys, + Subscriber: "test@example.com", }) if err != nil { t.Fatalf("failed to send notification: %s", err) @@ -367,30 +370,18 @@ func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.Pr return string(res), nil } -// test for the decoding helper function func Test_decodeVAPIDPublicKey(t *testing.T) { - privKeyB64, pubKeyB64, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() if err != nil { t.Fatalf("generating VAPID keys: %s", err) } - // as a baseline, decode using the library functions - privKeyBytes, err := decodeVapidKey(privKeyB64) - if err != nil { - t.Fatalf("decoding private key: %s", err) - } - privKey, err := generateVAPIDHeaderKeys(privKeyBytes) - if err != nil { - t.Fatalf("converting private key: %s", err) - } - wantPubKey := &privKey.PublicKey - // now decode using our test helper and compare the results - gotPubKey, err := decodeVAPIDPublicKey(pubKeyB64) + gotPubKey, err := decodeVAPIDPublicKey(vapidKeys.publicKey) if err != nil { t.Fatalf("decoding public key") } - if !gotPubKey.Equal(wantPubKey) { - t.Errorf("result differs:\ngot: %v\nwant: %v", gotPubKey, wantPubKey) + if !gotPubKey.Equal(vapidKeys.privateKey.Public()) { + t.Errorf("result differs:\ngot: %v\nwant: %v", gotPubKey, vapidKeys.privateKey.Public()) } } diff --git a/legacy.go b/legacy.go new file mode 100644 index 0000000..b151da9 --- /dev/null +++ b/legacy.go @@ -0,0 +1,76 @@ +package webpush + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "encoding/base64" + "fmt" + "math/big" +) + +// ecdhPublicKeyToECDSA converts an ECDH key to an ECDSA key. +// This is deprecated as per https://github.com/golang/go/issues/63963 +// but we need to do it in order to parse the legacy private key format. +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} + +func ecdhPrivateKeyToECDSA(key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { + // see https://github.com/golang/go/issues/63963 + pubKey, err := ecdhPublicKeyToECDSA(key.PublicKey()) + if err != nil { + return nil, fmt.Errorf("converting PublicKey part of *ecdh.PrivateKey: %w", err) + } + return &ecdsa.PrivateKey{ + PublicKey: *pubKey, + D: big.NewInt(0).SetBytes(key.Bytes()), + }, nil +} + +// DecodeLegacyVAPIDPrivateKey decodes the legacy string private key format +// returned by GenerateVAPIDKeys in v1. +func DecodeLegacyVAPIDPrivateKey(key string) (*VAPIDKeys, error) { + bytes, err := decodeSubscriptionKey(key) + if err != nil { + return nil, err + } + + ecdhPrivKey, err := ecdh.P256().NewPrivateKey(bytes) + if err != nil { + return nil, err + } + + ecdsaPrivKey, err := ecdhPrivateKeyToECDSA(ecdhPrivKey) + if err != nil { + return nil, err + } + + publicKey := base64.RawURLEncoding.EncodeToString(ecdhPrivKey.PublicKey().Bytes()) + return &VAPIDKeys{ + privateKey: ecdsaPrivKey, + publicKey: publicKey, + }, nil +} diff --git a/legacy_test.go b/legacy_test.go new file mode 100644 index 0000000..9d7987b --- /dev/null +++ b/legacy_test.go @@ -0,0 +1,55 @@ +package webpush + +import ( + "encoding/json" + "testing" +) + +const ( + legacyKeysJSON = `{"auth":"1F2Auk0iTJKXjJyiPlMu+w==","p256dh":"BJx6rbJEVu/Juf1xNEk6jO3pTxkyNFGqK1r/zw/iiaEnATH736mYYUSDLFRBsSaIK47vLsVmI+cNraliHyl/8WM="}` + + legacyVAPIDPublicKey = `BEkDdNnpEcD8M4mRGOFJWTDJ4GkDI5Xs3vpIOrAaBZKRCVv6V3sB3CFujTFiD6DHda7W8pCyChJDU205otrbCAw` + + legacyVAPIDPrivateKey = `F0RGqNXLeWLINzn7qIcLsF9lSbRSWgjqUVaoWB6zUqY` + + legacyVAPIDKeyAsJSON = `"-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgF0RGqNXLeWLINzn7\nqIcLsF9lSbRSWgjqUVaoWB6zUqahRANCAARJA3TZ6RHA/DOJkRjhSVkwyeBpAyOV\n7N76SDqwGgWSkQlb+ld7Adwhbo0xYg+gx3Wu1vKQsgoSQ1NtOaLa2wgM\n-----END PRIVATE KEY-----\n"` +) + +func TestLegacySubscriptionKeypair(t *testing.T) { + var keys Keys + err := json.Unmarshal([]byte(legacyKeysJSON), &keys) + if err != nil { + t.Fatal(err) + } + var emptyKeys Keys + if keys.Auth == emptyKeys.Auth { + t.Fatal("failed to deserialize auth key") + } + if keys.P256dh == emptyKeys.P256dh { + t.Fatal("failed to deserialize p256dh key") + } +} + +func TestLegacyVAPIDParsing(t *testing.T) { + // test that we can parse the legacy VAPID private key format (raw bytes + // of the private key as b64) and we get the same keys as the JSON format + vapidKeys, err := DecodeLegacyVAPIDPrivateKey(legacyVAPIDPrivateKey) + if err != nil { + t.Fatal(err) + } + + if vapidKeys.publicKey != legacyVAPIDPublicKey { + t.Fatal("decoded legacy VAPID private key, but did not recover true public key") + } + + vapidKeysFromJSON := new(VAPIDKeys) + if err := json.Unmarshal([]byte(legacyVAPIDKeyAsJSON), vapidKeysFromJSON); err != nil { + t.Fatal(err) + } + if !vapidKeys.privateKey.Equal(vapidKeysFromJSON.privateKey) { + t.Fatal("decoded legacy VAPID private key, but did not recover true private key") + } + if vapidKeys.publicKey != vapidKeysFromJSON.publicKey { + t.Fatal("unexpected private/public key mismatch") + } +} diff --git a/vapid.go b/vapid.go index 3204c97..dae676b 100644 --- a/vapid.go +++ b/vapid.go @@ -1,13 +1,14 @@ package webpush import ( - "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/x509" "encoding/base64" + "encoding/json" + "encoding/pem" "fmt" - "math/big" "net/url" "strings" "time" @@ -15,80 +16,134 @@ import ( jwt "github.com/golang-jwt/jwt/v5" ) -// GenerateVAPIDKeys will create a private and public VAPID key pair -func GenerateVAPIDKeys() (privateKey, publicKey string, err error) { - // Get the private key from the P256 curve - curve := ecdh.P256() +// VAPIDKeys is a public-private keypair for use in VAPID. +// It marshals to a JSON string containing the PEM of the PKCS8 +// of the private key. +type VAPIDKeys struct { + privateKey *ecdsa.PrivateKey + publicKey string // raw bytes encoding in urlsafe base64, as per RFC +} + +// PublicKeyString returns the base64url-encoded uncompressed public key of the keypair, +// as defined in RFC8292. +func (v *VAPIDKeys) PublicKeyString() string { + return v.publicKey +} + +// PrivateKey returns the private key of the keypair. +func (v *VAPIDKeys) PrivateKey() *ecdsa.PrivateKey { + return v.privateKey +} + +// Equal compares two VAPIDKeys for equality. +func (v *VAPIDKeys) Equal(o *VAPIDKeys) bool { + return v.privateKey.Equal(o.privateKey) +} + +var _ json.Marshaler = (*VAPIDKeys)(nil) +var _ json.Unmarshaler = (*VAPIDKeys)(nil) - private, err := curve.GenerateKey(rand.Reader) +// MarshalJSON implements json.Marshaler, allowing serialization to JSON. +func (v *VAPIDKeys) MarshalJSON() ([]byte, error) { + pkcs8bytes, err := x509.MarshalPKCS8PrivateKey(v.privateKey) if err != nil { - return + return nil, err } - - // Convert to base64 - publicKey = base64.RawURLEncoding.EncodeToString(private.PublicKey().Bytes()) - privateKey = base64.RawURLEncoding.EncodeToString(private.Bytes()) - return + pemBlock := pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8bytes, + } + pemBytes := pem.EncodeToMemory(&pemBlock) + if pemBytes == nil { + return nil, fmt.Errorf("could not encode VAPID keys as PEM") + } + return json.Marshal(string(pemBytes)) } -// Generates the ECDSA public and private keys for the JWT encryption -func generateVAPIDHeaderKeys(privateKey []byte) (*ecdsa.PrivateKey, error) { - key, err := ecdh.P256().NewPrivateKey(privateKey) +// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON. +func (v *VAPIDKeys) UnmarshalJSON(b []byte) error { + var pemKey string + if err := json.Unmarshal(b, &pemKey); err != nil { + return err + } + pemBlock, _ := pem.Decode([]byte(pemKey)) + if pemBlock == nil { + return fmt.Errorf("could not decode PEM block with VAPID keys") + } + privKey, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) if err != nil { - return nil, fmt.Errorf("validating private key: %w", err) + return err + } + privateKey, ok := privKey.(*ecdsa.PrivateKey) + if !ok { + return fmt.Errorf("Invalid type of private key %T", privateKey) } - converted, err := ecdhPrivateKeyToECDSA(key) + if privateKey.Curve != elliptic.P256() { + return fmt.Errorf("Invalid curve for private key %v", privateKey.Curve) + } + publicKeyStr, err := makePublicKeyString(privateKey) if err != nil { - return nil, fmt.Errorf("converting private key to crypto/ecdsa: %w", err) + return err // should not be possible since we confirmed P256 already } - return converted, nil + + // success + v.privateKey = privateKey + v.publicKey = publicKeyStr + return nil } -func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { - // see https://github.com/golang/go/issues/63963 - rawKey := key.Bytes() - switch key.Curve() { - case ecdh.P256(): - return &ecdsa.PublicKey{ - Curve: elliptic.P256(), - X: big.NewInt(0).SetBytes(rawKey[1:33]), - Y: big.NewInt(0).SetBytes(rawKey[33:]), - }, nil - case ecdh.P384(): - return &ecdsa.PublicKey{ - Curve: elliptic.P384(), - X: big.NewInt(0).SetBytes(rawKey[1:49]), - Y: big.NewInt(0).SetBytes(rawKey[49:]), - }, nil - case ecdh.P521(): - return &ecdsa.PublicKey{ - Curve: elliptic.P521(), - X: big.NewInt(0).SetBytes(rawKey[1:67]), - Y: big.NewInt(0).SetBytes(rawKey[67:]), - }, nil - default: - return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") +// GenerateVAPIDKeys generates a VAPID keypair (an ECDSA keypair on +// the P-256 curve). +func GenerateVAPIDKeys() (result *VAPIDKeys, err error) { + private, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return + } + + pubKeyECDH, err := private.PublicKey.ECDH() + if err != nil { + return } + publicKey := base64.RawURLEncoding.EncodeToString(pubKeyECDH.Bytes()) + + return &VAPIDKeys{ + privateKey: private, + publicKey: publicKey, + }, nil } -func ecdhPrivateKeyToECDSA(key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { - // see https://github.com/golang/go/issues/63963 - pubKey, err := ecdhPublicKeyToECDSA(key.PublicKey()) +// ECDSAToVAPIDKeys wraps an existing ecdsa.PrivateKey in VAPIDKeys for use in +// VAPID header signing. +func ECDSAToVAPIDKeys(privKey *ecdsa.PrivateKey) (result *VAPIDKeys, err error) { + if privKey.Curve != elliptic.P256() { + return nil, fmt.Errorf("Invalid curve for private key %v", privKey.Curve) + } + publicKeyString, err := makePublicKeyString(privKey) if err != nil { - return nil, fmt.Errorf("converting PublicKey part of *ecdh.PrivateKey: %w", err) + return nil, err } - return &ecdsa.PrivateKey{ - PublicKey: *pubKey, - D: big.NewInt(0).SetBytes(key.Bytes()), + return &VAPIDKeys{ + privateKey: privKey, + publicKey: publicKeyString, }, nil } +func makePublicKeyString(privKey *ecdsa.PrivateKey) (result string, err error) { + // to get the raw bytes we have to convert the public key to *ecdh.PublicKey + // this type assertion (from the crypto.PublicKey returned by (*ecdsa.PrivateKey).Public() + // to *ecdsa.PublicKey) cannot fail: + publicKey, err := privKey.Public().(*ecdsa.PublicKey).ECDH() + if err != nil { + return // should not be possible if we confirmed P256 already + } + return base64.RawURLEncoding.EncodeToString(publicKey.Bytes()), nil +} + // getVAPIDAuthorizationHeader func getVAPIDAuthorizationHeader( - endpoint, - subscriber, - vapidPublicKey, - vapidPrivateKey string, + endpoint string, + subscriber string, + vapidKeys *VAPIDKeys, expiration time.Time, ) (string, error) { // Create the JWT token @@ -108,25 +163,8 @@ func getVAPIDAuthorizationHeader( "sub": subscriber, }) - // Decode the VAPID private key - decodedVapidPrivateKey, err := decodeVapidKey(vapidPrivateKey) - if err != nil { - return "", err - } - - privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) - if err != nil { - return "", fmt.Errorf("generating VAPID header keys: %w", err) - } - // Sign token with private key - jwtString, err := token.SignedString(privKey) - if err != nil { - return "", err - } - - // Decode the VAPID public key - pubKey, err := decodeVapidKey(vapidPublicKey) + jwtString, err := token.SignedString(vapidKeys.privateKey) if err != nil { return "", err } @@ -134,17 +172,6 @@ func getVAPIDAuthorizationHeader( return fmt.Sprintf( "vapid t=%s, k=%s", jwtString, - base64.RawURLEncoding.EncodeToString(pubKey), + vapidKeys.publicKey, ), nil } - -// Need to decode the vapid private key in multiple base64 formats -// Solution from: https://github.com/SherClockHolmes/webpush-go/issues/29 -func decodeVapidKey(key string) ([]byte, error) { - bytes, err := base64.URLEncoding.DecodeString(key) - if err == nil { - return bytes, nil - } - - return base64.RawURLEncoding.DecodeString(key) -} diff --git a/vapid_test.go b/vapid_test.go index a80564a..9f68a14 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -4,7 +4,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "encoding/base64" + "encoding/json" "fmt" "strings" "testing" @@ -18,7 +18,7 @@ func TestVAPID(t *testing.T) { sub := "test@test.com" // Generate vapid keys - vapidPrivateKey, vapidPublicKey, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() if err != nil { t.Fatal(err) } @@ -27,8 +27,7 @@ func TestVAPID(t *testing.T) { vapidAuthHeader, err := getVAPIDAuthorizationHeader( s.Endpoint, sub, - vapidPublicKey, - vapidPrivateKey, + vapidKeys, time.Now().Add(time.Hour*12), ) if err != nil { @@ -44,17 +43,7 @@ func TestVAPID(t *testing.T) { } // To decode the token it needs the VAPID public key - b64 := base64.RawURLEncoding - decodedVapidPrivateKey, err := b64.DecodeString(vapidPrivateKey) - if err != nil { - t.Fatalf("Could not decode VAPID private key: %s", err) - } - - privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) - if err != nil { - t.Fatalf("Could not parse VAPID private key: %s", err) - } - return privKey.Public(), nil + return vapidKeys.privateKey.Public(), nil }) // Check the claims on the token @@ -78,17 +67,27 @@ func TestVAPID(t *testing.T) { } func TestVAPIDKeys(t *testing.T) { - privateKey, publicKey, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + + j, err := json.Marshal(vapidKeys) if err != nil { t.Fatal(err) } - if len(privateKey) != 43 { - t.Fatal("Generated incorrect VAPID private key") + vapidKeys2 := new(VAPIDKeys) + if err := json.Unmarshal(j, vapidKeys2); err != nil { + t.Fatal(err) + } + + if !vapidKeys.privateKey.Equal(vapidKeys2.privateKey) { + t.Fatalf("could not round-trip private key") } - if len(publicKey) != 87 { - t.Fatal("Generated incorrect VAPID public key") + if vapidKeys.publicKey != vapidKeys2.publicKey { + t.Fatalf("could not round-trip public key") } } @@ -187,3 +186,35 @@ func Test_ecdhPrivateKeyToECDSA(t *testing.T) { }) } } + +func TestVAPIDKeyFromECDSA(t *testing.T) { + v, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + privKey := v.PrivateKey() + v2, err := ECDSAToVAPIDKeys(privKey) + if err != nil { + t.Fatal(err) + } + if !v.Equal(v2) { + t.Fatal("ECDSAToVAPIDKeys failed round-trip") + } +} + +func BenchmarkVAPIDSigning(b *testing.B) { + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + b.Fatal(err) + } + expiration := time.Now().Add(24 * time.Hour) + b.ResetTimer() + for i := 0; i < b.N; i++ { + getVAPIDAuthorizationHeader( + "https://test.push.service/v2/AOWJIDuOMDSo6uNnRXYNsw", + "https://application.server", + vapidKeys, + expiration, + ) + } +} diff --git a/webpush.go b/webpush.go index a2d9768..e691598 100644 --- a/webpush.go +++ b/webpush.go @@ -10,6 +10,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -23,7 +24,11 @@ import ( const MaxRecordSize uint32 = 4096 -var ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length") +var ( + ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length") + + invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)") +) // saltFunc generates a salt of 16 bytes var saltFunc = func() ([]byte, error) { @@ -47,19 +52,86 @@ type Options struct { RecordSize uint32 // Limit the record size Subscriber string // Sub in VAPID JWT token Topic string // Set the Topic header to collapse a pending messages (Optional) - TTL int // Set the TTL on the endpoint POST request + TTL int // Set the TTL on the endpoint POST request, in seconds Urgency Urgency // Set the Urgency header to change a message priority (Optional) - VAPIDPublicKey string // VAPID public key, passed in VAPID Authorization header - VAPIDPrivateKey string // VAPID private key, used to sign VAPID JWT token + VAPIDKeys *VAPIDKeys // VAPID public-private keypair to generate the VAPID Authorization header VapidExpiration time.Time // optional expiration for VAPID JWT token (defaults to now + 12 hours) } -// Keys are the base64 encoded values from PushSubscription.getKey() +// Keys represents a subscription's keys (its ECDH public key on the P-256 curve +// and its 16-byte authentication secret). type Keys struct { + Auth [16]byte + P256dh *ecdh.PublicKey +} + +// Equal compares two Keys for equality. +func (k *Keys) Equal(o Keys) bool { + return k.Auth == o.Auth && k.P256dh.Equal(o.P256dh) +} + +var _ json.Marshaler = (*Keys)(nil) +var _ json.Unmarshaler = (*Keys)(nil) + +type marshaledKeys struct { Auth string `json:"auth"` P256dh string `json:"p256dh"` } +// MarshalJSON implements json.Marshaler, allowing serialization to JSON. +func (k *Keys) MarshalJSON() ([]byte, error) { + m := marshaledKeys{ + Auth: base64.RawStdEncoding.EncodeToString(k.Auth[:]), + P256dh: base64.RawStdEncoding.EncodeToString(k.P256dh.Bytes()), + } + return json.Marshal(&m) +} + +// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON. +func (k *Keys) UnmarshalJSON(b []byte) (err error) { + var m marshaledKeys + if err := json.Unmarshal(b, &m); err != nil { + return err + } + authBytes, err := decodeSubscriptionKey(m.Auth) + if err != nil { + return err + } + if len(authBytes) != 16 { + return fmt.Errorf("invalid auth bytes length %d (must be 16)", len(authBytes)) + } + copy(k.Auth[:], authBytes) + rawDHKey, err := decodeSubscriptionKey(m.P256dh) + if err != nil { + return err + } + k.P256dh, err = ecdh.P256().NewPublicKey(rawDHKey) + return err +} + +// DecodeSubscriptionKeys decodes and validates a base64-encoded pair of subscription keys +// (the authentication secret and ECDH public key). +func DecodeSubscriptionKeys(auth, p256dh string) (keys Keys, err error) { + authBytes, err := decodeSubscriptionKey(auth) + if err != nil { + return + } + if len(authBytes) != 16 { + err = invalidAuthKeyLength + return + } + copy(keys.Auth[:], authBytes) + dhBytes, err := decodeSubscriptionKey(p256dh) + if err != nil { + return + } + keys.P256dh, err = ecdh.P256().NewPublicKey(dhBytes) + if err != nil { + return + } + return +} + // Subscription represents a PushSubscription object from the Push API type Subscription struct { Endpoint string `json:"endpoint"` @@ -75,22 +147,6 @@ func SendNotification(message []byte, s *Subscription, options *Options) (*http. // Message Encryption for Web Push, and VAPID protocols. // FOR MORE INFORMATION SEE RFC8291: https://datatracker.ietf.org/doc/rfc8291 func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { - // Authentication secret (auth_secret) - authSecret, err := decodeSubscriptionKey(s.Keys.Auth) - if err != nil { - return nil, fmt.Errorf("decoding keys.auth: %w", err) - } - - // dh (Diffie Hellman) - dh, err := decodeSubscriptionKey(s.Keys.P256dh) - if err != nil { - return nil, fmt.Errorf("decoding keys.p256dh: %w", err) - } - userAgentPublicKey, err := ecdh.P256().NewPublicKey(dh) - if err != nil { - return nil, fmt.Errorf("validating keys.p256dh: %w", err) - } - // Generate 16 byte salt salt, err := saltFunc() if err != nil { @@ -108,7 +164,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri localPublicKey := localPrivateKey.PublicKey() // Combine application keys with receiver's EC public key to derive ECDH shared secret - sharedECDHSecret, err := localPrivateKey.ECDH(userAgentPublicKey) + sharedECDHSecret, err := localPrivateKey.ECDH(s.Keys.P256dh) if err != nil { return nil, fmt.Errorf("deriving shared secret: %w", err) } @@ -117,10 +173,10 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) - prkInfoBuf.Write(dh) + prkInfoBuf.Write(s.Keys.P256dh.Bytes()) prkInfoBuf.Write(localPublicKey.Bytes()) - prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret, prkInfoBuf.Bytes()) + prkHKDF := hkdf.New(hash, sharedECDHSecret, s.Keys.Auth[:], prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) if err != nil { return nil, err @@ -218,8 +274,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri vapidAuthHeader, err := getVAPIDAuthorizationHeader( s.Endpoint, options.Subscriber, - options.VAPIDPublicKey, - options.VAPIDPrivateKey, + options.VAPIDKeys, expiration, ) if err != nil { @@ -240,20 +295,13 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri } // decodeSubscriptionKey decodes a base64 subscription key. -// if necessary, add "=" padding to the key for URL decode func decodeSubscriptionKey(key string) ([]byte, error) { - // "=" padding - buf := bytes.NewBufferString(key) - if rem := len(key) % 4; rem != 0 { - buf.WriteString(strings.Repeat("=", 4-rem)) - } + key = strings.TrimRight(key, "=") - bytes, err := base64.StdEncoding.DecodeString(buf.String()) - if err == nil { - return bytes, nil + if strings.IndexByte(key, '+') != -1 || strings.IndexByte(key, '/') != -1 { + return base64.RawStdEncoding.DecodeString(key) } - - return base64.URLEncoding.DecodeString(buf.String()) + return base64.RawURLEncoding.DecodeString(key) } // Returns a key of length "length" given an hkdf function diff --git a/webpush_test.go b/webpush_test.go index d1f74c5..f9d36c8 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -1,6 +1,7 @@ package webpush import ( + "encoding/json" "net/http" "strings" "testing" @@ -13,39 +14,48 @@ func (*testHTTPClient) Do(*http.Request) (*http.Response, error) { } func getURLEncodedTestSubscription() *Subscription { - return &Subscription{ - Endpoint: "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", - Keys: Keys{ - P256dh: "BNNL5ZaTfK81qhXOx23-wewhigUeFb632jN6LvRWCFH1ubQr77FE_9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk", - Auth: "zqbxT6JKstKSY9JKibZLSQ", - }, + subJson := `{ + "endpoint": "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", + "keys": { + "p256dh": "BNNL5ZaTfK81qhXOx23-wewhigUeFb632jN6LvRWCFH1ubQr77FE_9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk", + "auth": "zqbxT6JKstKSY9JKibZLSQ" + } + }` + sub := new(Subscription) + if err := json.Unmarshal([]byte(subJson), sub); err != nil { + panic(err) } + return sub } func getStandardEncodedTestSubscription() *Subscription { - return &Subscription{ - Endpoint: "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", - Keys: Keys{ - P256dh: "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk=", - Auth: "zqbxT6JKstKSY9JKibZLSQ==", - }, + subJson := `{ + "endpoint": "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", + "keys": { + "p256dh": "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk=", + "auth": "zqbxT6JKstKSY9JKibZLSQ==" + } + }` + sub := new(Subscription) + if err := json.Unmarshal([]byte(subJson), sub); err != nil { + panic(err) } + return sub } func TestSendNotificationToURLEncodedSubscription(t *testing.T) { - priv, pub, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() if err != nil { t.Fatal(err) } resp, err := SendNotification([]byte("Test"), getURLEncodedTestSubscription(), &Options{ - HTTPClient: &testHTTPClient{}, - RecordSize: 3070, - Subscriber: "", - Topic: "test_topic", - TTL: 0, - Urgency: "low", - VAPIDPublicKey: pub, - VAPIDPrivateKey: priv, + HTTPClient: &testHTTPClient{}, + RecordSize: 3070, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", + VAPIDKeys: vapidKeys, }) if err != nil { t.Fatal(err) @@ -61,17 +71,17 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { } func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { - priv, _, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() if err != nil { t.Fatal(err) } resp, err := SendNotification([]byte("Test"), getStandardEncodedTestSubscription(), &Options{ - HTTPClient: &testHTTPClient{}, - Subscriber: "", - Topic: "test_topic", - TTL: 0, - Urgency: "low", - VAPIDPrivateKey: priv, + HTTPClient: &testHTTPClient{}, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", + VAPIDKeys: vapidKeys, }) if err != nil { t.Fatal(err) @@ -88,12 +98,11 @@ func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { func TestSendTooLargeNotification(t *testing.T) { _, err := SendNotification([]byte(strings.Repeat("Test", int(MaxRecordSize))), getStandardEncodedTestSubscription(), &Options{ - HTTPClient: &testHTTPClient{}, - Subscriber: "", - Topic: "test_topic", - TTL: 0, - Urgency: "low", - VAPIDPrivateKey: "testKey", + HTTPClient: &testHTTPClient{}, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", }) if err == nil { t.Fatalf("Error is nil, expected=%s", ErrMaxPadExceeded) From e8d3e209a8424adff454f1368514faeeb3b5f2cd Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 07/18] delete Content-Length header The passed value was wrong, but net/http ignores it anyway (it does a type assertion for bytes.Buffer and uses Len() on success). --- webpush.go | 1 - 1 file changed, 1 deletion(-) diff --git a/webpush.go b/webpush.go index e691598..ff2a92e 100644 --- a/webpush.go +++ b/webpush.go @@ -252,7 +252,6 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri } req.Header.Set("Content-Encoding", "aes128gcm") - req.Header.Set("Content-Length", strconv.Itoa(len(ciphertext))) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("TTL", strconv.Itoa(options.TTL)) From a5a4ca2c4563c4a3ac6e37f4a4c92ca21dda3af8 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 08/18] add Makefile --- .check-gofmt.sh | 13 +++++++++++++ Makefile | 6 ++++++ 2 files changed, 19 insertions(+) create mode 100755 .check-gofmt.sh create mode 100644 Makefile diff --git a/.check-gofmt.sh b/.check-gofmt.sh new file mode 100755 index 0000000..48a1aa3 --- /dev/null +++ b/.check-gofmt.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +SOURCES="." + +if [ "$1" = "--fix" ]; then + exec gofmt -s -w $SOURCES +fi + +if [ -n "$(gofmt -s -l $SOURCES)" ]; then + echo "Go code is not formatted correctly with \`gofmt -s\`:" + gofmt -s -d $SOURCES + exit 1 +fi diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7b72a66 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +.PHONY: test + +test: + go test . + go vet . + ./.check-gofmt.sh From b19d780e7876e859b6a29d42c8662724c3dbfccd Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 09/18] apply microoptimizations for VAPID See https://github.com/SherClockHolmes/webpush-go/pull/46 --- vapid.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vapid.go b/vapid.go index dae676b..ddf1b98 100644 --- a/vapid.go +++ b/vapid.go @@ -153,12 +153,12 @@ func getVAPIDAuthorizationHeader( } // Unless subscriber is an HTTPS URL, assume an e-mail address - if !strings.HasPrefix(subscriber, "https:") { - subscriber = fmt.Sprintf("mailto:%s", subscriber) + if !strings.HasPrefix(subscriber, "https:") && !strings.HasPrefix(subscriber, "mailto:") { + subscriber = "mailto:" + subscriber } token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ - "aud": fmt.Sprintf("%s://%s", subURL.Scheme, subURL.Host), + "aud": subURL.Scheme + "://" + subURL.Host, "exp": expiration.Unix(), "sub": subscriber, }) @@ -169,9 +169,5 @@ func getVAPIDAuthorizationHeader( return "", err } - return fmt.Sprintf( - "vapid t=%s, k=%s", - jwtString, - vapidKeys.publicKey, - ), nil + return "vapid t=" + jwtString + ", k=" + vapidKeys.publicKey, nil } From 5892574774a66360711c95140c980a4aa3f775fa Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 10/18] add shared default HTTP client --- webpush.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/webpush.go b/webpush.go index ff2a92e..87581d2 100644 --- a/webpush.go +++ b/webpush.go @@ -28,6 +28,8 @@ var ( ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length") invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)") + + defaultHTTPClient = &http.Client{} ) // saltFunc generates a salt of 16 bytes @@ -287,7 +289,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri if options.HTTPClient != nil { client = options.HTTPClient } else { - client = &http.Client{} + client = defaultHTTPClient } return client.Do(req) From 96e52ef9b896f22d6f81c7db7c6a2ec6d6267dad Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 11/18] add a basic benchmark --- webpush_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/webpush_test.go b/webpush_test.go index f9d36c8..436a495 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -1,6 +1,7 @@ package webpush import ( + "context" "encoding/json" "net/http" "strings" @@ -108,3 +109,28 @@ func TestSendTooLargeNotification(t *testing.T) { t.Fatalf("Error is nil, expected=%s", ErrMaxPadExceeded) } } + +func BenchmarkWebPush(b *testing.B) { + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + b.Fatal(err) + } + ctx := context.Background() + message := []byte("@time=2024-12-26T19:36:21.923Z;account=shivaram;msgid=56g9v3b92q6q4wtq43uhyqzegw :shivaram!~u@kca7nfgniet7q.irc PRIVMSG #redacted :[redacted message contents]") + sub := getStandardEncodedTestSubscription() + options := Options{ + HTTPClient: &testHTTPClient{}, + RecordSize: 2048, + Subscriber: "https://example.com", + TTL: 60 * 60 * 24, + Urgency: UrgencyHigh, + VAPIDKeys: vapidKeys, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := SendNotificationWithContext(ctx, message, sub, &options); err != nil { + b.Fatal(err) + } + } +} From f78537821698455c306cb285e60f31d40e5299be Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 12/18] microoptimizations to message encryption --- vapid.go | 4 ++ webpush.go | 167 +++++++++++++++++++++++------------------------- webpush_test.go | 2 +- 3 files changed, 86 insertions(+), 87 deletions(-) diff --git a/vapid.go b/vapid.go index ddf1b98..f4b0b53 100644 --- a/vapid.go +++ b/vapid.go @@ -146,6 +146,10 @@ func getVAPIDAuthorizationHeader( vapidKeys *VAPIDKeys, expiration time.Time, ) (string, error) { + if expiration.IsZero() { + expiration = time.Now().Add(time.Hour * 12) + } + // Create the JWT token subURL, err := url.Parse(endpoint) if err != nil { diff --git a/webpush.go b/webpush.go index 87581d2..5cae03d 100644 --- a/webpush.go +++ b/webpush.go @@ -25,24 +25,13 @@ import ( const MaxRecordSize uint32 = 4096 var ( - ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length") + ErrRecordSizeTooSmall = errors.New("record size too small for message") invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)") defaultHTTPClient = &http.Client{} ) -// saltFunc generates a salt of 16 bytes -var saltFunc = func() ([]byte, error) { - salt := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, salt) - if err != nil { - return salt, err - } - - return salt, nil -} - // HTTPClient is an interface for sending the notification HTTP request / testing type HTTPClient interface { Do(*http.Request) (*http.Response, error) @@ -149,36 +138,81 @@ func SendNotification(message []byte, s *Subscription, options *Options) (*http. // Message Encryption for Web Push, and VAPID protocols. // FOR MORE INFORMATION SEE RFC8291: https://datatracker.ietf.org/doc/rfc8291 func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { - // Generate 16 byte salt - salt, err := saltFunc() + // Compose message body (RFC8291 encryption of the message) + body, err := composeEncryptedBody(message, s.Keys, options.RecordSize) + if err != nil { + return nil, err + } + + // Get VAPID Authorization header + vapidAuthHeader, err := getVAPIDAuthorizationHeader( + s.Endpoint, + options.Subscriber, + options.VAPIDKeys, + options.VapidExpiration, + ) if err != nil { - return nil, fmt.Errorf("generating salt: %w", err) + return nil, err } - // Create the ecdh_secret shared key pair + // Compose and send the HTTP request + return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body) +} + +func composeEncryptedBody(message []byte, keys Keys, recordSize uint32) ([]byte, error) { + // Get the record size + if recordSize == 0 { + recordSize = MaxRecordSize + } else if recordSize < 128 { + return nil, ErrRecordSizeTooSmall + } + + // Allocate buffer to hold the eventual message + // [ header block ] [ ciphertext ] [ 16 byte AEAD tag ], totaling RecordSize bytes + // the ciphertext is the encryption of: [ message ] [ \x02 ] [ 0 or more \x00 as needed ] + recordBuf := make([]byte, recordSize) + // remainingBuf tracks our current writing position in recordBuf: + remainingBuf := recordBuf // Application server key pairs (single use) localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { return nil, err } - localPublicKey := localPrivateKey.PublicKey() + // Encryption Content-Coding Header + // +-----------+--------+-----------+---------------+ + // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | + // +-----------+--------+-----------+---------------+ + // in our case the keyid is localPublicKey.Bytes(), so 65 bytes + // First, generate the salt + _, err = rand.Read(remainingBuf[:16]) + if err != nil { + return nil, err + } + salt := remainingBuf[:16] + remainingBuf = remainingBuf[16:] + binary.BigEndian.PutUint32(remainingBuf[:], recordSize) + remainingBuf = remainingBuf[4:] + localPublicKeyBytes := localPublicKey.Bytes() + remainingBuf[0] = byte(len(localPublicKeyBytes)) + remainingBuf = remainingBuf[1:] + copy(remainingBuf[:], localPublicKeyBytes) + remainingBuf = remainingBuf[len(localPublicKeyBytes):] + // Combine application keys with receiver's EC public key to derive ECDH shared secret - sharedECDHSecret, err := localPrivateKey.ECDH(s.Keys.P256dh) + sharedECDHSecret, err := localPrivateKey.ECDH(keys.P256dh) if err != nil { return nil, fmt.Errorf("deriving shared secret: %w", err) } - hash := sha256.New - // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) - prkInfoBuf.Write(s.Keys.P256dh.Bytes()) + prkInfoBuf.Write(keys.P256dh.Bytes()) prkInfoBuf.Write(localPublicKey.Bytes()) - prkHKDF := hkdf.New(hash, sharedECDHSecret, s.Keys.Auth[:], prkInfoBuf.Bytes()) + prkHKDF := hkdf.New(sha256.New, sharedECDHSecret, keys.Auth[:], prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) if err != nil { return nil, err @@ -186,7 +220,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Derive Content Encryption Key contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") - contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo) + contentHKDF := hkdf.New(sha256.New, ikm, salt, contentEncryptionKeyInfo) contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) if err != nil { return nil, err @@ -194,7 +228,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Derive the Nonce nonceInfo := []byte("Content-Encoding: nonce\x00") - nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo) + nonceHKDF := hkdf.New(sha256.New, ikm, salt, nonceInfo) nonce, err := getHKDFKey(nonceHKDF, 12) if err != nil { return nil, err @@ -205,46 +239,37 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri if err != nil { return nil, err } - gcm, err := cipher.NewGCM(c) if err != nil { return nil, err } - // Get the record size - recordSize := options.RecordSize - if recordSize == 0 { - recordSize = MaxRecordSize - } - - recordLength := int(recordSize) - 16 - - // Encryption Content-Coding Header - recordBuf := bytes.NewBuffer(salt) - - rs := make([]byte, 4) - binary.BigEndian.PutUint32(rs, recordSize) - - recordBuf.Write(rs) - recordBuf.Write([]byte{byte(len(localPublicKey.Bytes()))}) - recordBuf.Write(localPublicKey.Bytes()) - - // Data - dataBuf := bytes.NewBuffer(message) - - // Pad content to max record size - 16 - header - // Padding ending delimeter - dataBuf.Write([]byte("\x02")) - if err := pad(dataBuf, recordLength-recordBuf.Len()); err != nil { - return nil, err + // need 1 byte for the 0x02 delimiter, 16 bytes for the AEAD tag + if len(remainingBuf) < len(message)+17 { + return nil, ErrRecordSizeTooSmall } + // Copy the message plaintext into the buffer + copy(remainingBuf[:], message[:]) + // The plaintext to be encrypted will include the padding delimiter and the padding; + // cut off the final 16 bytes that are reserved for the AEAD tag + plaintext := remainingBuf[:len(remainingBuf)-16] + remainingBuf = remainingBuf[len(message):] + // Add padding delimiter + remainingBuf[0] = '\x02' + remainingBuf = remainingBuf[1:] + // The rest of the buffer is already zero-padded + + // Encipher the plaintext in place, then add the AEAD tag at the end. + // "To reuse plaintext's storage for the encrypted output, use plaintext[:0] + // as dst. Otherwise, the remaining capacity of dst must not overlap plaintext." + gcm.Seal(plaintext[:0], nonce, plaintext, nil) + + return recordBuf, nil +} - // Compose the ciphertext - ciphertext := gcm.Seal([]byte{}, nonce, dataBuf.Bytes(), nil) - recordBuf.Write(ciphertext) - +func sendNotification(ctx context.Context, endpoint string, options *Options, vapidAuthHeader string, body []byte) (*http.Response, error) { // POST request - req, err := http.NewRequest("POST", s.Endpoint, recordBuf) + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(body)) if err != nil { return nil, err } @@ -266,22 +291,6 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri req.Header.Set("Urgency", string(options.Urgency)) } - expiration := options.VapidExpiration - if expiration.IsZero() { - expiration = time.Now().Add(time.Hour * 12) - } - - // Get VAPID Authorization header - vapidAuthHeader, err := getVAPIDAuthorizationHeader( - s.Endpoint, - options.Subscriber, - options.VAPIDKeys, - expiration, - ) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", vapidAuthHeader) // Send the request @@ -315,17 +324,3 @@ func getHKDFKey(hkdf io.Reader, length int) ([]byte, error) { return key, nil } - -func pad(payload *bytes.Buffer, maxPadLen int) error { - payloadLen := payload.Len() - if payloadLen > maxPadLen { - return ErrMaxPadExceeded - } - - padLen := maxPadLen - payloadLen - - padding := make([]byte, padLen) - payload.Write(padding) - - return nil -} diff --git a/webpush_test.go b/webpush_test.go index 436a495..d8f09a2 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -106,7 +106,7 @@ func TestSendTooLargeNotification(t *testing.T) { Urgency: "low", }) if err == nil { - t.Fatalf("Error is nil, expected=%s", ErrMaxPadExceeded) + t.Fatalf("Error is nil, expected=%s", ErrRecordSizeTooSmall) } } From 63d6767516e5635741e26029b479af5fc2daa6f6 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 13/18] remove legacy SendNotification API --- end2end_test.go | 3 ++- webpush.go | 19 ++++++++----------- webpush_test.go | 8 ++++---- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/end2end_test.go b/end2end_test.go index ff134c5..2cbf4a1 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -2,6 +2,7 @@ package webpush import ( "bytes" + "context" "crypto/aes" "crypto/cipher" "crypto/ecdh" @@ -139,7 +140,7 @@ func TestEnd2End(t *testing.T) { // ...and the application server uses the subscription to send a push notification sentMessage := "this is our test push notification" - resp, err := SendNotification([]byte(sentMessage), &applicationServer.subscription, &Options{ + resp, err := SendNotification(context.Background(), []byte(sentMessage), &applicationServer.subscription, &Options{ HTTPClient: mockPushService.Client(), VAPIDKeys: applicationServer.vapidKeys, Subscriber: "test@example.com", diff --git a/webpush.go b/webpush.go index 5cae03d..abdc9b1 100644 --- a/webpush.go +++ b/webpush.go @@ -129,17 +129,11 @@ type Subscription struct { Keys Keys `json:"keys"` } -// SendNotification calls SendNotificationWithContext with default context for backwards-compatibility -func SendNotification(message []byte, s *Subscription, options *Options) (*http.Response, error) { - return SendNotificationWithContext(context.Background(), message, s, options) -} - -// SendNotificationWithContext sends a push notification to a subscription's endpoint -// Message Encryption for Web Push, and VAPID protocols. -// FOR MORE INFORMATION SEE RFC8291: https://datatracker.ietf.org/doc/rfc8291 -func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { +// SendNotification sends a push notification to a subscription's endpoint, +// applying encryption (RFC 8291) and adding a VAPID header (RFC 8292). +func SendNotification(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { // Compose message body (RFC8291 encryption of the message) - body, err := composeEncryptedBody(message, s.Keys, options.RecordSize) + body, err := EncryptNotification(message, s.Keys, options.RecordSize) if err != nil { return nil, err } @@ -159,7 +153,10 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body) } -func composeEncryptedBody(message []byte, keys Keys, recordSize uint32) ([]byte, error) { +// EncryptNotification implements the encryption algorithm specified by RFC 8291 for web push +// (RFC 8188's aes128gcm content-encoding, with the key material derived from +// elliptic curve Diffie-Hellman over the P-256 curve). +func EncryptNotification(message []byte, keys Keys, recordSize uint32) ([]byte, error) { // Get the record size if recordSize == 0 { recordSize = MaxRecordSize diff --git a/webpush_test.go b/webpush_test.go index d8f09a2..3633b3c 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -49,7 +49,7 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { if err != nil { t.Fatal(err) } - resp, err := SendNotification([]byte("Test"), getURLEncodedTestSubscription(), &Options{ + resp, err := SendNotification(context.Background(), []byte("Test"), getURLEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, RecordSize: 3070, Subscriber: "", @@ -76,7 +76,7 @@ func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { if err != nil { t.Fatal(err) } - resp, err := SendNotification([]byte("Test"), getStandardEncodedTestSubscription(), &Options{ + resp, err := SendNotification(context.Background(), []byte("Test"), getStandardEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, Subscriber: "", Topic: "test_topic", @@ -98,7 +98,7 @@ func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { } func TestSendTooLargeNotification(t *testing.T) { - _, err := SendNotification([]byte(strings.Repeat("Test", int(MaxRecordSize))), getStandardEncodedTestSubscription(), &Options{ + _, err := SendNotification(context.Background(), []byte(strings.Repeat("Test", int(MaxRecordSize))), getStandardEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, Subscriber: "", Topic: "test_topic", @@ -129,7 +129,7 @@ func BenchmarkWebPush(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if _, err := SendNotificationWithContext(ctx, message, sub, &options); err != nil { + if _, err := SendNotification(ctx, message, sub, &options); err != nil { b.Fatal(err) } } From 06328d4749d9dc316793cf2298f1d8ad3e2f1be4 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 14/18] refactor end-to-end test to use ecdh natively --- end2end_test.go | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/end2end_test.go b/end2end_test.go index 2cbf4a1..2995ec1 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -7,7 +7,6 @@ import ( "crypto/cipher" "crypto/ecdh" "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "crypto/sha256" "encoding/base64" @@ -34,7 +33,7 @@ func TestEnd2End(t *testing.T) { // the data known to the user agent (browser) userAgent struct { publicVAPIDKey *ecdsa.PublicKey - subscriptionKey *ecdsa.PrivateKey + subscriptionKey *ecdh.PrivateKey authSecret [16]byte subscription Subscription receivedNotifications [][]byte @@ -109,12 +108,11 @@ func TestEnd2End(t *testing.T) { // what follows is the equivalent of PushManager.subscribe() in JS { // the user agent generates its own key pair so it can be sent encrypted messages - userAgent.subscriptionKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + userAgent.subscriptionKey, err = ecdh.P256().GenerateKey(rand.Reader) if err != nil { t.Fatalf("generating user agent keys: %s", err) } - // we need the ECDH representation - ecdhPublicKey, err := userAgent.subscriptionKey.PublicKey.ECDH() + ecdhPublicKey := userAgent.subscriptionKey.PublicKey() if err != nil { t.Fatalf("converting user agent public key to ECDH: %s", err) } @@ -267,7 +265,7 @@ func parseJWT(rawJWT string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, return token, nil } -func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.PrivateKey) (string, error) { +func decodeNotification(body []byte, authSecret [16]byte, userAgentECDHKey *ecdh.PrivateKey) (string, error) { // remember initial body length, before we start consuming it bodyLen := len(body) // the body is aes128gcm-encoded as described in RFC8188, @@ -293,14 +291,6 @@ func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.Pr if err != nil { return "", fmt.Errorf("converting public key to ECDH: %w", err) } - userAgentECDHKey, err := userAgentKey.ECDH() - if err != nil { - return "", fmt.Errorf("converting user agent private key to ECDH: %w", err) - } - userAgentECDHPublicKey, err := userAgentKey.PublicKey.ECDH() - if err != nil { - return "", fmt.Errorf("converting user agent public key to ECDH: %w", err) - } sharedECDHSecret, err := userAgentECDHKey.ECDH(pubKeyECDH) if err != nil { @@ -311,7 +301,7 @@ func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.Pr // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) - prkInfoBuf.Write(userAgentECDHPublicKey.Bytes()) // aka "dh" + prkInfoBuf.Write(userAgentECDHKey.PublicKey().Bytes()) // aka "dh" prkInfoBuf.Write(pubKeyECDH.Bytes()) prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret[:], prkInfoBuf.Bytes()) From e6f4fd4f4c114cf910d163decb479ec6b3576978 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:10 -0500 Subject: [PATCH 15/18] add a fuzz test --- end2end_test.go | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/end2end_test.go b/end2end_test.go index 2995ec1..84b70a0 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -167,7 +167,7 @@ func TestEnd2End(t *testing.T) { // the push service then forwards the notification to the user agent userAgent.receivedNotifications = pushService.receivedNotifications // and the user agent can decrypt them - receivedMessage, err := decodeNotification(userAgent.receivedNotifications[0], userAgent.authSecret, userAgent.subscriptionKey) + receivedMessage, err := decryptNotification(userAgent.receivedNotifications[0], userAgent.authSecret, userAgent.subscriptionKey) if err != nil { t.Fatalf("error decrypting notification in user agent: %s", err) } @@ -265,7 +265,7 @@ func parseJWT(rawJWT string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, return token, nil } -func decodeNotification(body []byte, authSecret [16]byte, userAgentECDHKey *ecdh.PrivateKey) (string, error) { +func decryptNotification(body []byte, authSecret [16]byte, userAgentECDHKey *ecdh.PrivateKey) (string, error) { // remember initial body length, before we start consuming it bodyLen := len(body) // the body is aes128gcm-encoded as described in RFC8188, @@ -361,6 +361,38 @@ func decodeNotification(body []byte, authSecret [16]byte, userAgentECDHKey *ecdh return string(res), nil } +func TestRandomRoundTrip(t *testing.T) { + privKey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + keys := Keys{ + P256dh: privKey.PublicKey(), + } + _, err = rand.Read(keys.Auth[:]) + if err != nil { + t.Fatal(err) + } + + for length := 0; length < 1900; length++ { + message := make([]byte, length) + if _, err := rand.Read(message); err != nil { + t.Fatal(err) + } + encrypted, err := EncryptNotification(message, keys, 2048) + if err != nil { + t.Fatal(err) + } + decrypted, err := decryptNotification(encrypted, keys.Auth, privKey) + if err != nil { + t.Fatal(err) + } + if string(message) != decrypted { + t.Fatalf("round trip failed at message length %d", length) + } + } +} + func Test_decodeVAPIDPublicKey(t *testing.T) { vapidKeys, err := GenerateVAPIDKeys() if err != nil { From 92752f02cb56bd233a72e469a70a7313b6dcd598 Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:27:52 -0500 Subject: [PATCH 16/18] bump to v2 --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b2c7db0..f1336f4 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/SherClockHolmes/webpush-go +module github.com/ergochat/webpush-go/v2 require ( github.com/golang-jwt/jwt/v5 v5.2.1 From 2decbc9c98f3702617d4c38eb367ad418395a7ad Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 29 Dec 2024 05:29:47 -0500 Subject: [PATCH 17/18] add github CI step --- .github/dependabot.yml | 8 -------- .github/workflows/build.yml | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 8 deletions(-) delete mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/build.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index e8c6655..0000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,8 +0,0 @@ -version: 2 -updates: -- package-ecosystem: gomod - directory: "/" - schedule: - interval: weekly - time: "13:00" - open-pull-requests-limit: 10 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1aa845b --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,24 @@ +name: "build" + +on: + pull_request: + branches: + - "master" + - "stable" + push: + branches: + - "master" + - "stable" + +jobs: + build: + runs-on: "ubuntu-22.04" + steps: + - name: "checkout repository" + uses: "actions/checkout@v3" + - name: "setup go" + uses: "actions/setup-go@v3" + with: + go-version: "1.23" + - name: "make test" + run: "make test" From 73cb41c50fba50a0e4729419cb6c4e30ae5837df Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Sun, 5 Jan 2025 00:32:20 -0500 Subject: [PATCH 18/18] documentation updates --- .gitignore | 2 ++ CHANGELOG.md | 14 ++++++++++++++ README.md | 18 ++++++++++-------- example/main.go | 30 ++++++++++++++++++++---------- 4 files changed, 46 insertions(+), 18 deletions(-) create mode 100644 CHANGELOG.md diff --git a/.gitignore b/.gitignore index 13b7c32..3ab8789 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ vendor/** .DS_Store *.out + +*.swp diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3dd34d5 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,14 @@ +# Changelog +All notable changes to webpush-go will be documented in this file. + +## [2.0.0] - 2025-01-01 + +* Update the `Keys` struct definition to store `Auth` as `[16]byte` and `P256dh` as `*ecdh.PublicKey` + * `Keys` can no longer be compared with `==`; use `(*Keys.Equal)` instead + * The JSON representation has not changed and is backwards and forwards compatible with v1 + * `DecodeSubscriptionKeys` is a helper to decode base64-encoded auth and p256dh parameters into a `Keys`, with validation +* Update the `VAPIDKeys` struct to contain a `(*ecdsa.PrivateKey)` + * `VAPIDKeys` can no longer be compared with `==`; use `(*VAPIDKeys).Equal` instead + * The JSON representation is now a JSON string containing the PEM of the PKCS8-encoded private key + * To parse the legacy representation (raw bytes of the private key encoded in base64), use `DecodeLegacyVAPIDPrivateKey` +* Renamed `SendNotificationWithContext` to `SendNotification`, removing the earlier `SendNotification` API. (Pass `context.Background()` as the context to restore the former behavior.) diff --git a/README.md b/README.md index 3a0ab13..461b4c8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ # webpush-go -[![Go Report Card](https://goreportcard.com/badge/github.com/SherClockHolmes/webpush-go)](https://goreportcard.com/report/github.com/SherClockHolmes/webpush-go) -[![GoDoc](https://godoc.org/github.com/SherClockHolmes/webpush-go?status.svg)](https://godoc.org/github.com/SherClockHolmes/webpush-go) +[![GoDoc](https://godoc.org/github.com/ergochat/webpush-go?status.svg)](https://godoc.org/github.com/ergochat/webpush-go) Web Push API Encryption with VAPID support. +This library is a fork of [SherClockHolmes/webpush-go](https://github.com/SherClockHolmes/webpush-go). + ```bash -go get -u github.com/SherClockHolmes/webpush-go +go get -u github.com/ergochat/webpush-go/v2 ``` ## Example @@ -19,20 +20,21 @@ package main import ( "encoding/json" - webpush "github.com/SherClockHolmes/webpush-go" + webpush "github.com/ergochat/webpush-go/v2" ) func main() { // Decode subscription s := &webpush.Subscription{} json.Unmarshal([]byte(""), s) + vapidKeys := new(webpush.VAPIDKeys) + json.Unmarshal([]byte("), vapidKeys) // Send Notification resp, err := webpush.SendNotification([]byte("Test"), s, &webpush.Options{ Subscriber: "example@example.com", - VAPIDPublicKey: "", - VAPIDPrivateKey: "", - TTL: 30, + VAPIDKeys: vapidKeys, + TTL: 3600, // seconds }) if err != nil { // TODO: Handle error @@ -46,7 +48,7 @@ func main() { Use the helper method `GenerateVAPIDKeys` to generate the VAPID key pair. ```golang -privateKey, publicKey, err := webpush.GenerateVAPIDKeys() +vapidKeys, err := webpush.GenerateVAPIDKeys() if err != nil { // TODO: Handle error } diff --git a/example/main.go b/example/main.go index 3445cdb..4769c13 100644 --- a/example/main.go +++ b/example/main.go @@ -1,29 +1,39 @@ package main import ( + "context" "encoding/json" + "time" - webpush "github.com/SherClockHolmes/webpush-go" + webpush "github.com/ergochat/webpush-go/v2" ) const ( subscription = `` - vapidPublicKey = "" vapidPrivateKey = "" ) func main() { // Decode subscription - s := &webpush.Subscription{} - json.Unmarshal([]byte(subscription), s) + sub := &webpush.Subscription{} + json.Unmarshal([]byte(subscription), sub) + // Decode VAPID keys + v := &webpush.VAPIDKeys{} + json.Unmarshal([]byte(vapidPrivateKey), v) // Send Notification - resp, err := webpush.SendNotification([]byte("Test"), s, &webpush.Options{ - Subscriber: "example@example.com", // Do not include "mailto:" - VAPIDPublicKey: vapidPublicKey, - VAPIDPrivateKey: vapidPrivateKey, - TTL: 30, - }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + resp, err := webpush.SendNotification( + ctx, + []byte("Test"), + sub, + &webpush.Options{ + Subscriber: "example@example.com", // Do not include "mailto:" + VAPIDKeys: v, + TTL: 30, + }, + ) if err != nil { // TODO: Handle error }