diff --git a/.check-gofmt.sh b/.check-gofmt.sh new file mode 100755 index 0000000..48a1aa3 --- /dev/null +++ b/.check-gofmt.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +SOURCES="." + +if [ "$1" = "--fix" ]; then + exec gofmt -s -w $SOURCES +fi + +if [ -n "$(gofmt -s -l $SOURCES)" ]; then + echo "Go code is not formatted correctly with \`gofmt -s\`:" + gofmt -s -d $SOURCES + exit 1 +fi diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index e8c6655..0000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,8 +0,0 @@ -version: 2 -updates: -- package-ecosystem: gomod - directory: "/" - schedule: - interval: weekly - time: "13:00" - open-pull-requests-limit: 10 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1aa845b --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,24 @@ +name: "build" + +on: + pull_request: + branches: + - "master" + - "stable" + push: + branches: + - "master" + - "stable" + +jobs: + build: + runs-on: "ubuntu-22.04" + steps: + - name: "checkout repository" + uses: "actions/checkout@v3" + - name: "setup go" + uses: "actions/setup-go@v3" + with: + go-version: "1.23" + - name: "make test" + run: "make test" diff --git a/.gitignore b/.gitignore index 13b7c32..3ab8789 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ vendor/** .DS_Store *.out + +*.swp diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..3dd34d5 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,14 @@ +# Changelog +All notable changes to webpush-go will be documented in this file. + +## [2.0.0] - 2025-01-01 + +* Update the `Keys` struct definition to store `Auth` as `[16]byte` and `P256dh` as `*ecdh.PublicKey` + * `Keys` can no longer be compared with `==`; use `(*Keys.Equal)` instead + * The JSON representation has not changed and is backwards and forwards compatible with v1 + * `DecodeSubscriptionKeys` is a helper to decode base64-encoded auth and p256dh parameters into a `Keys`, with validation +* Update the `VAPIDKeys` struct to contain a `(*ecdsa.PrivateKey)` + * `VAPIDKeys` can no longer be compared with `==`; use `(*VAPIDKeys).Equal` instead + * The JSON representation is now a JSON string containing the PEM of the PKCS8-encoded private key + * To parse the legacy representation (raw bytes of the private key encoded in base64), use `DecodeLegacyVAPIDPrivateKey` +* Renamed `SendNotificationWithContext` to `SendNotification`, removing the earlier `SendNotification` API. (Pass `context.Background()` as the context to restore the former behavior.) diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7b72a66 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +.PHONY: test + +test: + go test . + go vet . + ./.check-gofmt.sh diff --git a/README.md b/README.md index c313fc6..461b4c8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ # webpush-go -[![Go Report Card](https://goreportcard.com/badge/github.com/SherClockHolmes/webpush-go)](https://goreportcard.com/report/github.com/SherClockHolmes/webpush-go) -[![GoDoc](https://godoc.org/github.com/SherClockHolmes/webpush-go?status.svg)](https://godoc.org/github.com/SherClockHolmes/webpush-go) +[![GoDoc](https://godoc.org/github.com/ergochat/webpush-go?status.svg)](https://godoc.org/github.com/ergochat/webpush-go) Web Push API Encryption with VAPID support. +This library is a fork of [SherClockHolmes/webpush-go](https://github.com/SherClockHolmes/webpush-go). + ```bash -go get -u github.com/SherClockHolmes/webpush-go +go get -u github.com/ergochat/webpush-go/v2 ``` ## Example @@ -19,20 +20,21 @@ package main import ( "encoding/json" - webpush "github.com/SherClockHolmes/webpush-go" + webpush "github.com/ergochat/webpush-go/v2" ) func main() { // Decode subscription s := &webpush.Subscription{} json.Unmarshal([]byte(""), s) + vapidKeys := new(webpush.VAPIDKeys) + json.Unmarshal([]byte("), vapidKeys) // Send Notification resp, err := webpush.SendNotification([]byte("Test"), s, &webpush.Options{ Subscriber: "example@example.com", - VAPIDPublicKey: "", - VAPIDPrivateKey: "", - TTL: 30, + VAPIDKeys: vapidKeys, + TTL: 3600, // seconds }) if err != nil { // TODO: Handle error @@ -46,7 +48,7 @@ func main() { Use the helper method `GenerateVAPIDKeys` to generate the VAPID key pair. ```golang -privateKey, publicKey, err := webpush.GenerateVAPIDKeys() +vapidKeys, err := webpush.GenerateVAPIDKeys() if err != nil { // TODO: Handle error } @@ -54,7 +56,7 @@ if err != nil { ## Development -1. Install [Go 1.11+](https://golang.org/) +1. Install [Go 1.20+](https://golang.org/) 2. `go mod vendor` 3. `go test` diff --git a/end2end_test.go b/end2end_test.go new file mode 100644 index 0000000..84b70a0 --- /dev/null +++ b/end2end_test.go @@ -0,0 +1,410 @@ +package webpush + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + jwt "github.com/golang-jwt/jwt/v5" + "golang.org/x/crypto/hkdf" +) + +func TestEnd2End(t *testing.T) { + var ( + // the data known to the application server (backend, which uses webpush-go) + applicationServer struct { + vapidKeys *VAPIDKeys + subscription Subscription + } + // the data known to the user agent (browser) + userAgent struct { + publicVAPIDKey *ecdsa.PublicKey + subscriptionKey *ecdh.PrivateKey + authSecret [16]byte + subscription Subscription + receivedNotifications [][]byte + } + // the data known to the push server (which receives push messages on behalf of the user agent, e.g. Firestore) + pushService struct { + applicationServerKey *ecdsa.PublicKey + receivedNotifications [][]byte + } + + err error + ) + + // a VAPID key pair for the application server, usually only generated once and reused + applicationServer.vapidKeys, err = GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // The application server needs to inform the user agent of the public VAPID key. + // (We decode it first for ease of use.) + userAgent.publicVAPIDKey = applicationServer.vapidKeys.privateKey.Public().(*ecdsa.PublicKey) + + // We need a mock push service for webpush-go to send notifications to. + var mockPushService *httptest.Server + mockPushService = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // check that there's a valid vapid JWT + token, err := parseVapidAuthHeader( + r.Header.Get("Authorization"), + // by the time this function is called, this value will be set (see PushManager.subscribe() below) + pushService.applicationServerKey) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "invalid auth: %s", err) + return + } + // verify that the audience matches our URL + aud := token.Claims.(jwt.MapClaims)["aud"] + if aud != mockPushService.URL { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "JWT has bad audience, want %q, got %q", mockPushService.URL, aud) + return + } + // RFC8188 only allows for exactly one content encoding + if contentEncoding := r.Header.Get("Content-Encoding"); contentEncoding != "aes128gcm" { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "unsupported Content-Encoding, want %q, got %q", "aes128gcm", contentEncoding) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + // this suggests a broken connection, so log the error instead of sending it back + t.Errorf("failed to read request body: %s", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + contentLength, err := strconv.Atoi(r.Header.Get("Content-Length")) + if err != nil { + t.Errorf("invalid content-length `%s`: %v", r.Header.Get("Content-Length"), err) + } + if len(body) != contentLength { + t.Errorf("body length %d did not match content-length header %d", len(body), contentLength) + } + // store body for later decoding by user agent + // (the push service doesn't have the key required for decryption) + pushService.receivedNotifications = append(pushService.receivedNotifications, body) + + w.WriteHeader(http.StatusAccepted) + })) + defer mockPushService.Close() + + // what follows is the equivalent of PushManager.subscribe() in JS + { + // the user agent generates its own key pair so it can be sent encrypted messages + userAgent.subscriptionKey, err = ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generating user agent keys: %s", err) + } + ecdhPublicKey := userAgent.subscriptionKey.PublicKey() + if err != nil { + t.Fatalf("converting user agent public key to ECDH: %s", err) + } + // generate the shared auth secret + _, err = rand.Read(userAgent.authSecret[:]) + if err != nil { + t.Fatalf("generating user agent auth secret: %s", err) + } + // the user agent then performs a registration with the push service using that key, + // while also letting the push service know the application server key to expect. + pushService.applicationServerKey = userAgent.publicVAPIDKey + userAgent.subscription = Subscription{ + Keys: Keys{ + Auth: userAgent.authSecret, + P256dh: ecdhPublicKey, + }, + Endpoint: mockPushService.URL, + } + } + + // the user agent sends its subscription to the application server... + applicationServer.subscription = userAgent.subscription + + // ...and the application server uses the subscription to send a push notification + sentMessage := "this is our test push notification" + resp, err := SendNotification(context.Background(), []byte(sentMessage), &applicationServer.subscription, &Options{ + HTTPClient: mockPushService.Client(), + VAPIDKeys: applicationServer.vapidKeys, + Subscriber: "test@example.com", + }) + if err != nil { + t.Fatalf("failed to send notification: %s", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("error closing mock push service response body: %s", err) + } + }() + // check for success + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading mock push service response body: %s", err) + } + if resp.StatusCode/100 != 2 { + t.Errorf("unexpected push service status code %d, body: %s", resp.StatusCode, respBody) + } + + // the push server should now have received the notification + if l := len(pushService.receivedNotifications); l != 1 { + t.Fatalf("Want 1 notification received by push service, got %d", l) + } + // the push service then forwards the notification to the user agent + userAgent.receivedNotifications = pushService.receivedNotifications + // and the user agent can decrypt them + receivedMessage, err := decryptNotification(userAgent.receivedNotifications[0], userAgent.authSecret, userAgent.subscriptionKey) + if err != nil { + t.Fatalf("error decrypting notification in user agent: %s", err) + } + if receivedMessage != sentMessage { + t.Errorf("Sent notification %q, but got %q", sentMessage, receivedMessage) + } +} + +func decodeVAPIDPublicKey(publicVAPIDKey string) (*ecdsa.PublicKey, error) { + publicVAPIDKeyBytes, err := base64.RawURLEncoding.DecodeString(publicVAPIDKey) + if err != nil { + return nil, fmt.Errorf("base64-decoding public VAPID key: %w", err) + } + return decodeECDSAPublicKey(publicVAPIDKeyBytes) +} + +func decodeECDSAPublicKey(bytes []byte) (*ecdsa.PublicKey, error) { + ecdhKey, err := ecdh.P256().NewPublicKey(bytes) + if err != nil { + return nil, fmt.Errorf("parsing public VAPID key: %w", err) + } + res, err := ecdhPublicKeyToECDSA(ecdhKey) + if err != nil { + return nil, fmt.Errorf("converting public VAPID key from *ecdh.PublicKey to *ecdsa.PublicKey: %w", err) + } + return res, nil +} + +func parseVapidAuthHeader(authHeader string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + if authHeader == "" { + return nil, fmt.Errorf("missing auth header") + } + // the Authorization header should be of the form "vapid t=JWT, k=key" (RFC8292) + // we need to extract the JWT (JSON Web Token) from t to check the signature using k + authBody, found := strings.CutPrefix(authHeader, "vapid ") + if !found { + return nil, fmt.Errorf("Authorization header is not vapid: %s", authHeader) + } + authFields := strings.Split(authBody, ",") + rawJWT := "" + rawKey := "" + for _, field := range authFields { + kv := strings.SplitN(field, "=", 2) + if len(kv) < 2 { + return nil, fmt.Errorf("push service vapid Authorization header field %q malformed", field) + } + key := strings.TrimSpace(kv[0]) + val := strings.TrimSpace(kv[1]) + switch key { + case "t": + rawJWT = val + case "k": + rawKey = val + default: + // other fields irrelevant to us + } + } + if rawJWT == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"t\" field (JWT)") + } + if rawKey == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"k\" field") + } + key, err := decodeVAPIDPublicKey(rawKey) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization key: %w", err) + } + // check that the key matches the known applicationServerKey + // (RFC8292 4.2) + if !key.Equal(applicationServerKey) { + // in real code, this would mean the user agent needs to resubscribe with the new applicationServerKey + return nil, fmt.Errorf("vapid Authorization key does not match applicationServerKey from subscription") + } + + // verify the JWT signature + token, err := parseJWT(rawJWT, key) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization JWT: %w", err) + } + return token, nil +} + +func parseJWT(rawJWT string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + token, err := jwt.Parse(rawJWT, func(t *jwt.Token) (interface{}, error) { + switch t.Method.Alg() { + case "ES256": + return applicationServerKey, nil + default: + return nil, fmt.Errorf("unsupported JWT signing alg %q", t.Method.Alg()) + } + }) + if err != nil { + return nil, fmt.Errorf("decoding JWT %s: %w", rawJWT, err) + } + return token, nil +} + +func decryptNotification(body []byte, authSecret [16]byte, userAgentECDHKey *ecdh.PrivateKey) (string, error) { + // remember initial body length, before we start consuming it + bodyLen := len(body) + // the body is aes128gcm-encoded as described in RFC8188, + // starting with this header: + // +-----------+--------+-----------+---------------+ + // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | + // +-----------+--------+-----------+---------------+ + salt, body := body[:16], body[16:] + recordSize, body := int(binary.BigEndian.Uint32(body[:4])), body[4:] + idLen, body := int(uint8(body[0])), body[1:] + rawPubKey, body := body[:idLen], body[idLen:] + if bodyLen != recordSize { + // this could mean a multi-record message was sent, this simplified parser does not support those. + return "", fmt.Errorf("expected body length %d, got %d", recordSize, bodyLen) + } + + // parse keys and derive shared secret + pubKey, err := decodeECDSAPublicKey(rawPubKey) + if err != nil { + return "", fmt.Errorf("decoding public key from header: %w", err) + } + pubKeyECDH, err := pubKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting public key to ECDH: %w", err) + } + + sharedECDHSecret, err := userAgentECDHKey.ECDH(pubKeyECDH) + if err != nil { + return "", fmt.Errorf("deriving shared secret from notification public key and user agent private key: %w", err) + } + + hash := sha256.New + + // ikm + prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) + prkInfoBuf.Write(userAgentECDHKey.PublicKey().Bytes()) // aka "dh" + prkInfoBuf.Write(pubKeyECDH.Bytes()) + + prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret[:], prkInfoBuf.Bytes()) + ikm, err := getHKDFKey(prkHKDF, 32) + if err != nil { + return "", fmt.Errorf("deriving ikm: %w", err) + } + + // Derive Content Encryption Key + contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") + contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo) + contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) + if err != nil { + return "", fmt.Errorf("deriving content encryption key: %w", err) + } + + // Derive the Nonce + nonceInfo := []byte("Content-Encoding: nonce\x00") + nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo) + nonce, err := getHKDFKey(nonceHKDF, 12) + if err != nil { + return "", fmt.Errorf("deriving nonce: %w", err) + } + + // Cipher + c, err := aes.NewCipher(contentEncryptionKey) + if err != nil { + return "", fmt.Errorf("creating cipher block: %w", err) + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return "", fmt.Errorf("creating GCM: %w", err) + } + + // Decrypt + res, err := gcm.Open(nil, nonce, body, nil) + if err != nil { + return "", fmt.Errorf("decrypting: %w", err) + } + + // the message is padded with 0x02 0x00 0x00 0x00 [...] 0x00, we need to remove that + lastNull := len(res) + for ; lastNull > 0 && res[lastNull-1] == 0x00; lastNull-- { + } + if lastNull == 0 { + // we expect at least one 0x02 (or 0x01) before the nulls, not finding one is wrong + return "", fmt.Errorf("decryption yielded only %d null bytes", len(res)) + } + if beforeNull := res[lastNull-1]; beforeNull != 0x02 { + // if we get an 0x01, it means we have a multi-record message, this mock does not implement those + return "", fmt.Errorf("padding nulls in decrypted message should be preceded by 0x02 delimiter, got %02X", beforeNull) + } + // strip trailing nulls and separating 0x02 + res = res[:lastNull-1] + + return string(res), nil +} + +func TestRandomRoundTrip(t *testing.T) { + privKey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + keys := Keys{ + P256dh: privKey.PublicKey(), + } + _, err = rand.Read(keys.Auth[:]) + if err != nil { + t.Fatal(err) + } + + for length := 0; length < 1900; length++ { + message := make([]byte, length) + if _, err := rand.Read(message); err != nil { + t.Fatal(err) + } + encrypted, err := EncryptNotification(message, keys, 2048) + if err != nil { + t.Fatal(err) + } + decrypted, err := decryptNotification(encrypted, keys.Auth, privKey) + if err != nil { + t.Fatal(err) + } + if string(message) != decrypted { + t.Fatalf("round trip failed at message length %d", length) + } + } +} + +func Test_decodeVAPIDPublicKey(t *testing.T) { + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // now decode using our test helper and compare the results + gotPubKey, err := decodeVAPIDPublicKey(vapidKeys.publicKey) + if err != nil { + t.Fatalf("decoding public key") + } + if !gotPubKey.Equal(vapidKeys.privateKey.Public()) { + t.Errorf("result differs:\ngot: %v\nwant: %v", gotPubKey, vapidKeys.privateKey.Public()) + } +} diff --git a/example/main.go b/example/main.go index 3445cdb..4769c13 100644 --- a/example/main.go +++ b/example/main.go @@ -1,29 +1,39 @@ package main import ( + "context" "encoding/json" + "time" - webpush "github.com/SherClockHolmes/webpush-go" + webpush "github.com/ergochat/webpush-go/v2" ) const ( subscription = `` - vapidPublicKey = "" vapidPrivateKey = "" ) func main() { // Decode subscription - s := &webpush.Subscription{} - json.Unmarshal([]byte(subscription), s) + sub := &webpush.Subscription{} + json.Unmarshal([]byte(subscription), sub) + // Decode VAPID keys + v := &webpush.VAPIDKeys{} + json.Unmarshal([]byte(vapidPrivateKey), v) // Send Notification - resp, err := webpush.SendNotification([]byte("Test"), s, &webpush.Options{ - Subscriber: "example@example.com", // Do not include "mailto:" - VAPIDPublicKey: vapidPublicKey, - VAPIDPrivateKey: vapidPrivateKey, - TTL: 30, - }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + resp, err := webpush.SendNotification( + ctx, + []byte("Test"), + sub, + &webpush.Options{ + Subscriber: "example@example.com", // Do not include "mailto:" + VAPIDKeys: v, + TTL: 30, + }, + ) if err != nil { // TODO: Handle error } diff --git a/go.mod b/go.mod index 6b0604f..f1336f4 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ -module github.com/SherClockHolmes/webpush-go +module github.com/ergochat/webpush-go/v2 require ( - github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt/v5 v5.2.1 golang.org/x/crypto v0.9.0 ) -go 1.13 +go 1.20 diff --git a/go.sum b/go.sum index d9575c4..4a47b67 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= diff --git a/legacy.go b/legacy.go new file mode 100644 index 0000000..b151da9 --- /dev/null +++ b/legacy.go @@ -0,0 +1,76 @@ +package webpush + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "encoding/base64" + "fmt" + "math/big" +) + +// ecdhPublicKeyToECDSA converts an ECDH key to an ECDSA key. +// This is deprecated as per https://github.com/golang/go/issues/63963 +// but we need to do it in order to parse the legacy private key format. +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} + +func ecdhPrivateKeyToECDSA(key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { + // see https://github.com/golang/go/issues/63963 + pubKey, err := ecdhPublicKeyToECDSA(key.PublicKey()) + if err != nil { + return nil, fmt.Errorf("converting PublicKey part of *ecdh.PrivateKey: %w", err) + } + return &ecdsa.PrivateKey{ + PublicKey: *pubKey, + D: big.NewInt(0).SetBytes(key.Bytes()), + }, nil +} + +// DecodeLegacyVAPIDPrivateKey decodes the legacy string private key format +// returned by GenerateVAPIDKeys in v1. +func DecodeLegacyVAPIDPrivateKey(key string) (*VAPIDKeys, error) { + bytes, err := decodeSubscriptionKey(key) + if err != nil { + return nil, err + } + + ecdhPrivKey, err := ecdh.P256().NewPrivateKey(bytes) + if err != nil { + return nil, err + } + + ecdsaPrivKey, err := ecdhPrivateKeyToECDSA(ecdhPrivKey) + if err != nil { + return nil, err + } + + publicKey := base64.RawURLEncoding.EncodeToString(ecdhPrivKey.PublicKey().Bytes()) + return &VAPIDKeys{ + privateKey: ecdsaPrivKey, + publicKey: publicKey, + }, nil +} diff --git a/legacy_test.go b/legacy_test.go new file mode 100644 index 0000000..9d7987b --- /dev/null +++ b/legacy_test.go @@ -0,0 +1,55 @@ +package webpush + +import ( + "encoding/json" + "testing" +) + +const ( + legacyKeysJSON = `{"auth":"1F2Auk0iTJKXjJyiPlMu+w==","p256dh":"BJx6rbJEVu/Juf1xNEk6jO3pTxkyNFGqK1r/zw/iiaEnATH736mYYUSDLFRBsSaIK47vLsVmI+cNraliHyl/8WM="}` + + legacyVAPIDPublicKey = `BEkDdNnpEcD8M4mRGOFJWTDJ4GkDI5Xs3vpIOrAaBZKRCVv6V3sB3CFujTFiD6DHda7W8pCyChJDU205otrbCAw` + + legacyVAPIDPrivateKey = `F0RGqNXLeWLINzn7qIcLsF9lSbRSWgjqUVaoWB6zUqY` + + legacyVAPIDKeyAsJSON = `"-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgF0RGqNXLeWLINzn7\nqIcLsF9lSbRSWgjqUVaoWB6zUqahRANCAARJA3TZ6RHA/DOJkRjhSVkwyeBpAyOV\n7N76SDqwGgWSkQlb+ld7Adwhbo0xYg+gx3Wu1vKQsgoSQ1NtOaLa2wgM\n-----END PRIVATE KEY-----\n"` +) + +func TestLegacySubscriptionKeypair(t *testing.T) { + var keys Keys + err := json.Unmarshal([]byte(legacyKeysJSON), &keys) + if err != nil { + t.Fatal(err) + } + var emptyKeys Keys + if keys.Auth == emptyKeys.Auth { + t.Fatal("failed to deserialize auth key") + } + if keys.P256dh == emptyKeys.P256dh { + t.Fatal("failed to deserialize p256dh key") + } +} + +func TestLegacyVAPIDParsing(t *testing.T) { + // test that we can parse the legacy VAPID private key format (raw bytes + // of the private key as b64) and we get the same keys as the JSON format + vapidKeys, err := DecodeLegacyVAPIDPrivateKey(legacyVAPIDPrivateKey) + if err != nil { + t.Fatal(err) + } + + if vapidKeys.publicKey != legacyVAPIDPublicKey { + t.Fatal("decoded legacy VAPID private key, but did not recover true public key") + } + + vapidKeysFromJSON := new(VAPIDKeys) + if err := json.Unmarshal([]byte(legacyVAPIDKeyAsJSON), vapidKeysFromJSON); err != nil { + t.Fatal(err) + } + if !vapidKeys.privateKey.Equal(vapidKeysFromJSON.privateKey) { + t.Fatal("decoded legacy VAPID private key, but did not recover true private key") + } + if vapidKeys.publicKey != vapidKeysFromJSON.publicKey { + t.Fatal("unexpected private/public key mismatch") + } +} diff --git a/vapid.go b/vapid.go index fe2c580..f4b0b53 100644 --- a/vapid.go +++ b/vapid.go @@ -4,114 +4,174 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/x509" "encoding/base64" + "encoding/json" + "encoding/pem" "fmt" - "math/big" "net/url" + "strings" "time" - "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) -// GenerateVAPIDKeys will create a private and public VAPID key pair -func GenerateVAPIDKeys() (privateKey, publicKey string, err error) { - // Get the private key from the P256 curve - curve := elliptic.P256() +// VAPIDKeys is a public-private keypair for use in VAPID. +// It marshals to a JSON string containing the PEM of the PKCS8 +// of the private key. +type VAPIDKeys struct { + privateKey *ecdsa.PrivateKey + publicKey string // raw bytes encoding in urlsafe base64, as per RFC +} + +// PublicKeyString returns the base64url-encoded uncompressed public key of the keypair, +// as defined in RFC8292. +func (v *VAPIDKeys) PublicKeyString() string { + return v.publicKey +} + +// PrivateKey returns the private key of the keypair. +func (v *VAPIDKeys) PrivateKey() *ecdsa.PrivateKey { + return v.privateKey +} + +// Equal compares two VAPIDKeys for equality. +func (v *VAPIDKeys) Equal(o *VAPIDKeys) bool { + return v.privateKey.Equal(o.privateKey) +} - private, x, y, err := elliptic.GenerateKey(curve, rand.Reader) +var _ json.Marshaler = (*VAPIDKeys)(nil) +var _ json.Unmarshaler = (*VAPIDKeys)(nil) + +// MarshalJSON implements json.Marshaler, allowing serialization to JSON. +func (v *VAPIDKeys) MarshalJSON() ([]byte, error) { + pkcs8bytes, err := x509.MarshalPKCS8PrivateKey(v.privateKey) if err != nil { - return + return nil, err } + pemBlock := pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8bytes, + } + pemBytes := pem.EncodeToMemory(&pemBlock) + if pemBytes == nil { + return nil, fmt.Errorf("could not encode VAPID keys as PEM") + } + return json.Marshal(string(pemBytes)) +} - public := elliptic.Marshal(curve, x, y) - - // Convert to base64 - publicKey = base64.RawURLEncoding.EncodeToString(public) - privateKey = base64.RawURLEncoding.EncodeToString(private) +// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON. +func (v *VAPIDKeys) UnmarshalJSON(b []byte) error { + var pemKey string + if err := json.Unmarshal(b, &pemKey); err != nil { + return err + } + pemBlock, _ := pem.Decode([]byte(pemKey)) + if pemBlock == nil { + return fmt.Errorf("could not decode PEM block with VAPID keys") + } + privKey, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) + if err != nil { + return err + } + privateKey, ok := privKey.(*ecdsa.PrivateKey) + if !ok { + return fmt.Errorf("Invalid type of private key %T", privateKey) + } + if privateKey.Curve != elliptic.P256() { + return fmt.Errorf("Invalid curve for private key %v", privateKey.Curve) + } + publicKeyStr, err := makePublicKeyString(privateKey) + if err != nil { + return err // should not be possible since we confirmed P256 already + } - return + // success + v.privateKey = privateKey + v.publicKey = publicKeyStr + return nil } -// Generates the ECDSA public and private keys for the JWT encryption -func generateVAPIDHeaderKeys(privateKey []byte) *ecdsa.PrivateKey { - // Public key - curve := elliptic.P256() - px, py := curve.ScalarMult( - curve.Params().Gx, - curve.Params().Gy, - privateKey, - ) +// GenerateVAPIDKeys generates a VAPID keypair (an ECDSA keypair on +// the P-256 curve). +func GenerateVAPIDKeys() (result *VAPIDKeys, err error) { + private, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return + } - pubKey := ecdsa.PublicKey{ - Curve: curve, - X: px, - Y: py, + pubKeyECDH, err := private.PublicKey.ECDH() + if err != nil { + return } + publicKey := base64.RawURLEncoding.EncodeToString(pubKeyECDH.Bytes()) - // Private key - d := &big.Int{} - d.SetBytes(privateKey) + return &VAPIDKeys{ + privateKey: private, + publicKey: publicKey, + }, nil +} - return &ecdsa.PrivateKey{ - PublicKey: pubKey, - D: d, +// ECDSAToVAPIDKeys wraps an existing ecdsa.PrivateKey in VAPIDKeys for use in +// VAPID header signing. +func ECDSAToVAPIDKeys(privKey *ecdsa.PrivateKey) (result *VAPIDKeys, err error) { + if privKey.Curve != elliptic.P256() { + return nil, fmt.Errorf("Invalid curve for private key %v", privKey.Curve) + } + publicKeyString, err := makePublicKeyString(privKey) + if err != nil { + return nil, err } + return &VAPIDKeys{ + privateKey: privKey, + publicKey: publicKeyString, + }, nil +} + +func makePublicKeyString(privKey *ecdsa.PrivateKey) (result string, err error) { + // to get the raw bytes we have to convert the public key to *ecdh.PublicKey + // this type assertion (from the crypto.PublicKey returned by (*ecdsa.PrivateKey).Public() + // to *ecdsa.PublicKey) cannot fail: + publicKey, err := privKey.Public().(*ecdsa.PublicKey).ECDH() + if err != nil { + return // should not be possible if we confirmed P256 already + } + return base64.RawURLEncoding.EncodeToString(publicKey.Bytes()), nil } // getVAPIDAuthorizationHeader func getVAPIDAuthorizationHeader( - endpoint, - subscriber, - vapidPublicKey, - vapidPrivateKey string, + endpoint string, + subscriber string, + vapidKeys *VAPIDKeys, expiration time.Time, ) (string, error) { + if expiration.IsZero() { + expiration = time.Now().Add(time.Hour * 12) + } + // Create the JWT token subURL, err := url.Parse(endpoint) if err != nil { return "", err } + // Unless subscriber is an HTTPS URL, assume an e-mail address + if !strings.HasPrefix(subscriber, "https:") && !strings.HasPrefix(subscriber, "mailto:") { + subscriber = "mailto:" + subscriber + } + token := jwt.NewWithClaims(jwt.SigningMethodES256, jwt.MapClaims{ - "aud": fmt.Sprintf("%s://%s", subURL.Scheme, subURL.Host), + "aud": subURL.Scheme + "://" + subURL.Host, "exp": expiration.Unix(), - "sub": fmt.Sprintf("mailto:%s", subscriber), + "sub": subscriber, }) - // Decode the VAPID private key - decodedVapidPrivateKey, err := decodeVapidKey(vapidPrivateKey) - if err != nil { - return "", err - } - - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) - // Sign token with private key - jwtString, err := token.SignedString(privKey) - if err != nil { - return "", err - } - - // Decode the VAPID public key - pubKey, err := decodeVapidKey(vapidPublicKey) + jwtString, err := token.SignedString(vapidKeys.privateKey) if err != nil { return "", err } - return fmt.Sprintf( - "vapid t=%s, k=%s", - jwtString, - base64.RawURLEncoding.EncodeToString(pubKey), - ), nil -} - -// Need to decode the vapid private key in multiple base64 formats -// Solution from: https://github.com/SherClockHolmes/webpush-go/issues/29 -func decodeVapidKey(key string) ([]byte, error) { - bytes, err := base64.URLEncoding.DecodeString(key) - if err == nil { - return bytes, nil - } - - return base64.RawURLEncoding.DecodeString(key) + return "vapid t=" + jwtString + ", k=" + vapidKeys.publicKey, nil } diff --git a/vapid_test.go b/vapid_test.go index be4cca8..9f68a14 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -1,13 +1,16 @@ package webpush import ( - "encoding/base64" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" "fmt" "strings" "testing" "time" - "github.com/golang-jwt/jwt" + jwt "github.com/golang-jwt/jwt/v5" ) func TestVAPID(t *testing.T) { @@ -15,7 +18,7 @@ func TestVAPID(t *testing.T) { sub := "test@test.com" // Generate vapid keys - vapidPrivateKey, vapidPublicKey, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() if err != nil { t.Fatal(err) } @@ -24,8 +27,7 @@ func TestVAPID(t *testing.T) { vapidAuthHeader, err := getVAPIDAuthorizationHeader( s.Endpoint, sub, - vapidPublicKey, - vapidPrivateKey, + vapidKeys, time.Now().Add(time.Hour*12), ) if err != nil { @@ -41,14 +43,7 @@ func TestVAPID(t *testing.T) { } // To decode the token it needs the VAPID public key - b64 := base64.RawURLEncoding - decodedVapidPrivateKey, err := b64.DecodeString(vapidPrivateKey) - if err != nil { - t.Fatal("Could not decode VAPID private key") - } - - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) - return privKey.Public(), nil + return vapidKeys.privateKey.Public(), nil }) // Check the claims on the token @@ -72,17 +67,27 @@ func TestVAPID(t *testing.T) { } func TestVAPIDKeys(t *testing.T) { - privateKey, publicKey, err := GenerateVAPIDKeys() + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + + j, err := json.Marshal(vapidKeys) if err != nil { t.Fatal(err) } - if len(privateKey) != 43 { - t.Fatal("Generated incorrect VAPID private key") + vapidKeys2 := new(VAPIDKeys) + if err := json.Unmarshal(j, vapidKeys2); err != nil { + t.Fatal(err) + } + + if !vapidKeys.privateKey.Equal(vapidKeys2.privateKey) { + t.Fatalf("could not round-trip private key") } - if len(publicKey) != 87 { - t.Fatal("Generated incorrect VAPID public key") + if vapidKeys.publicKey != vapidKeys2.publicKey { + t.Fatalf("could not round-trip public key") } } @@ -100,3 +105,116 @@ func getTokenFromAuthorizationHeader(tokenHeader string, t *testing.T) string { return tsplit[1][:len(tsplit[1])-1] } + +func Test_ecdhPublicKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + original := &pk.PublicKey + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) + } + roundtrip, err := ecdhPublicKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func Test_ecdhPrivateKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PrivateKey to ecdh.PrivateKey: %s", err) + } + roundtrip, err := ecdhPrivateKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PrivateKey back to ecdsa.PrivateKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func TestVAPIDKeyFromECDSA(t *testing.T) { + v, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + privKey := v.PrivateKey() + v2, err := ECDSAToVAPIDKeys(privKey) + if err != nil { + t.Fatal(err) + } + if !v.Equal(v2) { + t.Fatal("ECDSAToVAPIDKeys failed round-trip") + } +} + +func BenchmarkVAPIDSigning(b *testing.B) { + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + b.Fatal(err) + } + expiration := time.Now().Add(24 * time.Hour) + b.ResetTimer() + for i := 0; i < b.N; i++ { + getVAPIDAuthorizationHeader( + "https://test.push.service/v2/AOWJIDuOMDSo6uNnRXYNsw", + "https://application.server", + vapidKeys, + expiration, + ) + } +} diff --git a/webpush.go b/webpush.go index 4c85ad6..abdc9b1 100644 --- a/webpush.go +++ b/webpush.go @@ -5,12 +5,14 @@ import ( "context" "crypto/aes" "crypto/cipher" - "crypto/elliptic" + "crypto/ecdh" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/binary" + "encoding/json" "errors" + "fmt" "io" "net/http" "strconv" @@ -22,18 +24,13 @@ import ( const MaxRecordSize uint32 = 4096 -var ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length") +var ( + ErrRecordSizeTooSmall = errors.New("record size too small for message") -// saltFunc generates a salt of 16 bytes -var saltFunc = func() ([]byte, error) { - salt := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, salt) - if err != nil { - return salt, err - } + invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)") - return salt, nil -} + defaultHTTPClient = &http.Client{} +) // HTTPClient is an interface for sending the notification HTTP request / testing type HTTPClient interface { @@ -46,86 +43,173 @@ type Options struct { RecordSize uint32 // Limit the record size Subscriber string // Sub in VAPID JWT token Topic string // Set the Topic header to collapse a pending messages (Optional) - TTL int // Set the TTL on the endpoint POST request + TTL int // Set the TTL on the endpoint POST request, in seconds Urgency Urgency // Set the Urgency header to change a message priority (Optional) - VAPIDPublicKey string // VAPID public key, passed in VAPID Authorization header - VAPIDPrivateKey string // VAPID private key, used to sign VAPID JWT token + VAPIDKeys *VAPIDKeys // VAPID public-private keypair to generate the VAPID Authorization header VapidExpiration time.Time // optional expiration for VAPID JWT token (defaults to now + 12 hours) } -// Keys are the base64 encoded values from PushSubscription.getKey() +// Keys represents a subscription's keys (its ECDH public key on the P-256 curve +// and its 16-byte authentication secret). type Keys struct { + Auth [16]byte + P256dh *ecdh.PublicKey +} + +// Equal compares two Keys for equality. +func (k *Keys) Equal(o Keys) bool { + return k.Auth == o.Auth && k.P256dh.Equal(o.P256dh) +} + +var _ json.Marshaler = (*Keys)(nil) +var _ json.Unmarshaler = (*Keys)(nil) + +type marshaledKeys struct { Auth string `json:"auth"` P256dh string `json:"p256dh"` } +// MarshalJSON implements json.Marshaler, allowing serialization to JSON. +func (k *Keys) MarshalJSON() ([]byte, error) { + m := marshaledKeys{ + Auth: base64.RawStdEncoding.EncodeToString(k.Auth[:]), + P256dh: base64.RawStdEncoding.EncodeToString(k.P256dh.Bytes()), + } + return json.Marshal(&m) +} + +// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON. +func (k *Keys) UnmarshalJSON(b []byte) (err error) { + var m marshaledKeys + if err := json.Unmarshal(b, &m); err != nil { + return err + } + authBytes, err := decodeSubscriptionKey(m.Auth) + if err != nil { + return err + } + if len(authBytes) != 16 { + return fmt.Errorf("invalid auth bytes length %d (must be 16)", len(authBytes)) + } + copy(k.Auth[:], authBytes) + rawDHKey, err := decodeSubscriptionKey(m.P256dh) + if err != nil { + return err + } + k.P256dh, err = ecdh.P256().NewPublicKey(rawDHKey) + return err +} + +// DecodeSubscriptionKeys decodes and validates a base64-encoded pair of subscription keys +// (the authentication secret and ECDH public key). +func DecodeSubscriptionKeys(auth, p256dh string) (keys Keys, err error) { + authBytes, err := decodeSubscriptionKey(auth) + if err != nil { + return + } + if len(authBytes) != 16 { + err = invalidAuthKeyLength + return + } + copy(keys.Auth[:], authBytes) + dhBytes, err := decodeSubscriptionKey(p256dh) + if err != nil { + return + } + keys.P256dh, err = ecdh.P256().NewPublicKey(dhBytes) + if err != nil { + return + } + return +} + // Subscription represents a PushSubscription object from the Push API type Subscription struct { Endpoint string `json:"endpoint"` Keys Keys `json:"keys"` } -// SendNotification calls SendNotificationWithContext with default context for backwards-compatibility -func SendNotification(message []byte, s *Subscription, options *Options) (*http.Response, error) { - return SendNotificationWithContext(context.Background(), message, s, options) -} - -// SendNotificationWithContext sends a push notification to a subscription's endpoint -// Message Encryption for Web Push, and VAPID protocols. -// FOR MORE INFORMATION SEE RFC8291: https://datatracker.ietf.org/doc/rfc8291 -func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { - // Authentication secret (auth_secret) - authSecret, err := decodeSubscriptionKey(s.Keys.Auth) +// SendNotification sends a push notification to a subscription's endpoint, +// applying encryption (RFC 8291) and adding a VAPID header (RFC 8292). +func SendNotification(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) { + // Compose message body (RFC8291 encryption of the message) + body, err := EncryptNotification(message, s.Keys, options.RecordSize) if err != nil { return nil, err } - // dh (Diffie Hellman) - dh, err := decodeSubscriptionKey(s.Keys.P256dh) + // Get VAPID Authorization header + vapidAuthHeader, err := getVAPIDAuthorizationHeader( + s.Endpoint, + options.Subscriber, + options.VAPIDKeys, + options.VapidExpiration, + ) if err != nil { return nil, err } - // Generate 16 byte salt - salt, err := saltFunc() - if err != nil { - return nil, err + // Compose and send the HTTP request + return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body) +} + +// EncryptNotification implements the encryption algorithm specified by RFC 8291 for web push +// (RFC 8188's aes128gcm content-encoding, with the key material derived from +// elliptic curve Diffie-Hellman over the P-256 curve). +func EncryptNotification(message []byte, keys Keys, recordSize uint32) ([]byte, error) { + // Get the record size + if recordSize == 0 { + recordSize = MaxRecordSize + } else if recordSize < 128 { + return nil, ErrRecordSizeTooSmall } - // Create the ecdh_secret shared key pair - curve := elliptic.P256() + // Allocate buffer to hold the eventual message + // [ header block ] [ ciphertext ] [ 16 byte AEAD tag ], totaling RecordSize bytes + // the ciphertext is the encryption of: [ message ] [ \x02 ] [ 0 or more \x00 as needed ] + recordBuf := make([]byte, recordSize) + // remainingBuf tracks our current writing position in recordBuf: + remainingBuf := recordBuf // Application server key pairs (single use) - localPrivateKey, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { return nil, err } + localPublicKey := localPrivateKey.PublicKey() - localPublicKey := elliptic.Marshal(curve, x, y) - - // Combine application keys with receiver's EC public key - sharedX, sharedY := elliptic.Unmarshal(curve, dh) - if sharedX == nil { - return nil, errors.New("Unmarshal Error: Public key is not a valid point on the curve") + // Encryption Content-Coding Header + // +-----------+--------+-----------+---------------+ + // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | + // +-----------+--------+-----------+---------------+ + // in our case the keyid is localPublicKey.Bytes(), so 65 bytes + // First, generate the salt + _, err = rand.Read(remainingBuf[:16]) + if err != nil { + return nil, err } - - // Derive ECDH shared secret - sx, sy := curve.ScalarMult(sharedX, sharedY, localPrivateKey) - if !curve.IsOnCurve(sx, sy) { - return nil, errors.New("Encryption error: ECDH shared secret isn't on curve") + salt := remainingBuf[:16] + remainingBuf = remainingBuf[16:] + binary.BigEndian.PutUint32(remainingBuf[:], recordSize) + remainingBuf = remainingBuf[4:] + localPublicKeyBytes := localPublicKey.Bytes() + remainingBuf[0] = byte(len(localPublicKeyBytes)) + remainingBuf = remainingBuf[1:] + copy(remainingBuf[:], localPublicKeyBytes) + remainingBuf = remainingBuf[len(localPublicKeyBytes):] + + // Combine application keys with receiver's EC public key to derive ECDH shared secret + sharedECDHSecret, err := localPrivateKey.ECDH(keys.P256dh) + if err != nil { + return nil, fmt.Errorf("deriving shared secret: %w", err) } - mlen := curve.Params().BitSize / 8 - sharedECDHSecret := make([]byte, mlen) - sx.FillBytes(sharedECDHSecret) - - hash := sha256.New // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) - prkInfoBuf.Write(dh) - prkInfoBuf.Write(localPublicKey) + prkInfoBuf.Write(keys.P256dh.Bytes()) + prkInfoBuf.Write(localPublicKey.Bytes()) - prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret, prkInfoBuf.Bytes()) + prkHKDF := hkdf.New(sha256.New, sharedECDHSecret, keys.Auth[:], prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) if err != nil { return nil, err @@ -133,7 +217,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Derive Content Encryption Key contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") - contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo) + contentHKDF := hkdf.New(sha256.New, ikm, salt, contentEncryptionKeyInfo) contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) if err != nil { return nil, err @@ -141,7 +225,7 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Derive the Nonce nonceInfo := []byte("Content-Encoding: nonce\x00") - nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo) + nonceHKDF := hkdf.New(sha256.New, ikm, salt, nonceInfo) nonce, err := getHKDFKey(nonceHKDF, 12) if err != nil { return nil, err @@ -152,46 +236,37 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri if err != nil { return nil, err } - gcm, err := cipher.NewGCM(c) if err != nil { return nil, err } - // Get the record size - recordSize := options.RecordSize - if recordSize == 0 { - recordSize = MaxRecordSize - } - - recordLength := int(recordSize) - 16 - - // Encryption Content-Coding Header - recordBuf := bytes.NewBuffer(salt) - - rs := make([]byte, 4) - binary.BigEndian.PutUint32(rs, recordSize) - - recordBuf.Write(rs) - recordBuf.Write([]byte{byte(len(localPublicKey))}) - recordBuf.Write(localPublicKey) - - // Data - dataBuf := bytes.NewBuffer(message) - - // Pad content to max record size - 16 - header - // Padding ending delimeter - dataBuf.Write([]byte("\x02")) - if err := pad(dataBuf, recordLength-recordBuf.Len()); err != nil { - return nil, err + // need 1 byte for the 0x02 delimiter, 16 bytes for the AEAD tag + if len(remainingBuf) < len(message)+17 { + return nil, ErrRecordSizeTooSmall } + // Copy the message plaintext into the buffer + copy(remainingBuf[:], message[:]) + // The plaintext to be encrypted will include the padding delimiter and the padding; + // cut off the final 16 bytes that are reserved for the AEAD tag + plaintext := remainingBuf[:len(remainingBuf)-16] + remainingBuf = remainingBuf[len(message):] + // Add padding delimiter + remainingBuf[0] = '\x02' + remainingBuf = remainingBuf[1:] + // The rest of the buffer is already zero-padded + + // Encipher the plaintext in place, then add the AEAD tag at the end. + // "To reuse plaintext's storage for the encrypted output, use plaintext[:0] + // as dst. Otherwise, the remaining capacity of dst must not overlap plaintext." + gcm.Seal(plaintext[:0], nonce, plaintext, nil) + + return recordBuf, nil +} - // Compose the ciphertext - ciphertext := gcm.Seal([]byte{}, nonce, dataBuf.Bytes(), nil) - recordBuf.Write(ciphertext) - +func sendNotification(ctx context.Context, endpoint string, options *Options, vapidAuthHeader string, body []byte) (*http.Response, error) { // POST request - req, err := http.NewRequest("POST", s.Endpoint, recordBuf) + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(body)) if err != nil { return nil, err } @@ -201,7 +276,6 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri } req.Header.Set("Content-Encoding", "aes128gcm") - req.Header.Set("Content-Length", strconv.Itoa(len(ciphertext))) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("TTL", strconv.Itoa(options.TTL)) @@ -214,23 +288,6 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri req.Header.Set("Urgency", string(options.Urgency)) } - expiration := options.VapidExpiration - if expiration.IsZero() { - expiration = time.Now().Add(time.Hour * 12) - } - - // Get VAPID Authorization header - vapidAuthHeader, err := getVAPIDAuthorizationHeader( - s.Endpoint, - options.Subscriber, - options.VAPIDPublicKey, - options.VAPIDPrivateKey, - expiration, - ) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", vapidAuthHeader) // Send the request @@ -238,27 +295,20 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri if options.HTTPClient != nil { client = options.HTTPClient } else { - client = &http.Client{} + client = defaultHTTPClient } return client.Do(req) } // decodeSubscriptionKey decodes a base64 subscription key. -// if necessary, add "=" padding to the key for URL decode func decodeSubscriptionKey(key string) ([]byte, error) { - // "=" padding - buf := bytes.NewBufferString(key) - if rem := len(key) % 4; rem != 0 { - buf.WriteString(strings.Repeat("=", 4-rem)) - } + key = strings.TrimRight(key, "=") - bytes, err := base64.StdEncoding.DecodeString(buf.String()) - if err == nil { - return bytes, nil + if strings.IndexByte(key, '+') != -1 || strings.IndexByte(key, '/') != -1 { + return base64.RawStdEncoding.DecodeString(key) } - - return base64.URLEncoding.DecodeString(buf.String()) + return base64.RawURLEncoding.DecodeString(key) } // Returns a key of length "length" given an hkdf function @@ -271,17 +321,3 @@ func getHKDFKey(hkdf io.Reader, length int) ([]byte, error) { return key, nil } - -func pad(payload *bytes.Buffer, maxPadLen int) error { - payloadLen := payload.Len() - if payloadLen > maxPadLen { - return ErrMaxPadExceeded - } - - padLen := maxPadLen - payloadLen - - padding := make([]byte, padLen) - payload.Write(padding) - - return nil -} diff --git a/webpush_test.go b/webpush_test.go index 807a1f7..3633b3c 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -1,6 +1,8 @@ package webpush import ( + "context" + "encoding/json" "net/http" "strings" "testing" @@ -13,35 +15,48 @@ func (*testHTTPClient) Do(*http.Request) (*http.Response, error) { } func getURLEncodedTestSubscription() *Subscription { - return &Subscription{ - Endpoint: "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", - Keys: Keys{ - P256dh: "BNNL5ZaTfK81qhXOx23-wewhigUeFb632jN6LvRWCFH1ubQr77FE_9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk", - Auth: "zqbxT6JKstKSY9JKibZLSQ", - }, + subJson := `{ + "endpoint": "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", + "keys": { + "p256dh": "BNNL5ZaTfK81qhXOx23-wewhigUeFb632jN6LvRWCFH1ubQr77FE_9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk", + "auth": "zqbxT6JKstKSY9JKibZLSQ" + } + }` + sub := new(Subscription) + if err := json.Unmarshal([]byte(subJson), sub); err != nil { + panic(err) } + return sub } func getStandardEncodedTestSubscription() *Subscription { - return &Subscription{ - Endpoint: "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", - Keys: Keys{ - P256dh: "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk=", - Auth: "zqbxT6JKstKSY9JKibZLSQ==", - }, + subJson := `{ + "endpoint": "https://updates.push.services.mozilla.com/wpush/v2/gAAAAA", + "keys": { + "p256dh": "BNNL5ZaTfK81qhXOx23+wewhigUeFb632jN6LvRWCFH1ubQr77FE/9qV1FuojuRmHP42zmf34rXgW80OvUVDgTk=", + "auth": "zqbxT6JKstKSY9JKibZLSQ==" + } + }` + sub := new(Subscription) + if err := json.Unmarshal([]byte(subJson), sub); err != nil { + panic(err) } + return sub } func TestSendNotificationToURLEncodedSubscription(t *testing.T) { - resp, err := SendNotification([]byte("Test"), getURLEncodedTestSubscription(), &Options{ - HTTPClient: &testHTTPClient{}, - RecordSize: 3070, - Subscriber: "", - Topic: "test_topic", - TTL: 0, - Urgency: "low", - VAPIDPublicKey: "test-public", - VAPIDPrivateKey: "test-private", + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + resp, err := SendNotification(context.Background(), []byte("Test"), getURLEncodedTestSubscription(), &Options{ + HTTPClient: &testHTTPClient{}, + RecordSize: 3070, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", + VAPIDKeys: vapidKeys, }) if err != nil { t.Fatal(err) @@ -49,7 +64,7 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { if resp.StatusCode != 201 { t.Fatalf( - "Incorreect status code, expected=%d, got=%d", + "Incorrect status code, expected=%d, got=%d", resp.StatusCode, 201, ) @@ -57,13 +72,17 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { } func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { - resp, err := SendNotification([]byte("Test"), getStandardEncodedTestSubscription(), &Options{ - HTTPClient: &testHTTPClient{}, - Subscriber: "", - Topic: "test_topic", - TTL: 0, - Urgency: "low", - VAPIDPrivateKey: "testKey", + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + resp, err := SendNotification(context.Background(), []byte("Test"), getStandardEncodedTestSubscription(), &Options{ + HTTPClient: &testHTTPClient{}, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", + VAPIDKeys: vapidKeys, }) if err != nil { t.Fatal(err) @@ -79,15 +98,39 @@ func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { } func TestSendTooLargeNotification(t *testing.T) { - _, err := SendNotification([]byte(strings.Repeat("Test", int(MaxRecordSize))), getStandardEncodedTestSubscription(), &Options{ - HTTPClient: &testHTTPClient{}, - Subscriber: "", - Topic: "test_topic", - TTL: 0, - Urgency: "low", - VAPIDPrivateKey: "testKey", + _, err := SendNotification(context.Background(), []byte(strings.Repeat("Test", int(MaxRecordSize))), getStandardEncodedTestSubscription(), &Options{ + HTTPClient: &testHTTPClient{}, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", }) if err == nil { - t.Fatalf("Error is nil, expected=%s", ErrMaxPadExceeded) + t.Fatalf("Error is nil, expected=%s", ErrRecordSizeTooSmall) + } +} + +func BenchmarkWebPush(b *testing.B) { + vapidKeys, err := GenerateVAPIDKeys() + if err != nil { + b.Fatal(err) + } + ctx := context.Background() + message := []byte("@time=2024-12-26T19:36:21.923Z;account=shivaram;msgid=56g9v3b92q6q4wtq43uhyqzegw :shivaram!~u@kca7nfgniet7q.irc PRIVMSG #redacted :[redacted message contents]") + sub := getStandardEncodedTestSubscription() + options := Options{ + HTTPClient: &testHTTPClient{}, + RecordSize: 2048, + Subscriber: "https://example.com", + TTL: 60 * 60 * 24, + Urgency: UrgencyHigh, + VAPIDKeys: vapidKeys, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := SendNotification(ctx, message, sub, &options); err != nil { + b.Fatal(err) + } } }