Skip to content

Commit c20ca66

Browse files
authored
Fix Connect requiring signing key in dev mode (#114)
1 parent d414580 commit c20ca66

File tree

12 files changed

+165
-45
lines changed

12 files changed

+165
-45
lines changed

client.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ func NewClient(opts ClientOpts) (Client, error) {
109109
return nil, err
110110
}
111111

112+
if opts.Logger == nil {
113+
opts.Logger = slog.Default()
114+
}
115+
112116
c := &apiClient{
113117
ClientOpts: opts,
114118
}

connect.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,21 @@ func Connect(ctx context.Context, opts ConnectOpts) (connect.WorkerConnection, e
7979
return nil, fmt.Errorf("invalid handler passed")
8080
}
8181

82+
var hashedKey []byte
83+
var hashedFallbackKey []byte
8284
signingKey := defaultClient.h.GetSigningKey()
8385
if signingKey == "" {
84-
return nil, fmt.Errorf("signing key is required")
85-
}
86-
87-
hashedKey, err := hashedSigningKey([]byte(signingKey))
88-
if err != nil {
89-
return nil, fmt.Errorf("failed to hash signing key: %w", err)
90-
}
86+
if !defaultClient.h.isDev() {
87+
// Signing key is only required in cloud mode.
88+
return nil, fmt.Errorf("signing key is required")
89+
}
90+
} else {
91+
var err error
92+
hashedKey, err = hashedSigningKey([]byte(signingKey))
93+
if err != nil {
94+
return nil, fmt.Errorf("failed to hash signing key: %w", err)
95+
}
9196

92-
var hashedFallbackKey []byte
93-
{
9497
if fallbackKey := defaultClient.h.GetSigningKeyFallback(); fallbackKey != "" {
9598
hashedFallbackKey, err = hashedSigningKey([]byte(fallbackKey))
9699
if err != nil {

connect/handler.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"io"
9+
"log/slog"
10+
"net/url"
11+
"os"
12+
"runtime"
13+
"sync/atomic"
14+
"time"
15+
816
"github.com/coder/websocket"
917
"github.com/inngest/inngest/pkg/execution/state"
1018
"github.com/inngest/inngest/pkg/sdk"
@@ -13,13 +21,6 @@ import (
1321
"github.com/inngest/inngestgo/internal/sdkrequest"
1422
"github.com/pbnjay/memory"
1523
"golang.org/x/sync/errgroup"
16-
"io"
17-
"log/slog"
18-
"net/url"
19-
"os"
20-
"runtime"
21-
"sync/atomic"
22-
"time"
2324
)
2425

2526
const (
@@ -159,7 +160,7 @@ type authContext struct {
159160

160161
func (h *connectHandler) Connect(ctx context.Context) (WorkerConnection, error) {
161162
signingKey := h.opts.HashedSigningKey
162-
if len(signingKey) == 0 {
163+
if len(signingKey) == 0 && !h.opts.IsDev {
163164
return nil, fmt.Errorf("hashed signing key is required")
164165
}
165166

connect/workerapi.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7-
"github.com/inngest/inngest/proto/gen/connect/v1"
8-
"google.golang.org/protobuf/proto"
97
"io"
108
"log/slog"
119
"net/http"
10+
11+
"github.com/inngest/inngest/proto/gen/connect/v1"
12+
"google.golang.org/protobuf/proto"
1213
)
1314

1415
type workerApiClient struct {
@@ -36,7 +37,9 @@ func (a *workerApiClient) start(ctx context.Context, hashedSigningKey []byte, re
3637
}
3738

3839
httpReq.Header.Set("Content-Type", "application/protobuf")
39-
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(hashedSigningKey)))
40+
if hashedSigningKey != nil {
41+
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(hashedSigningKey)))
42+
}
4043

4144
if a.env != nil {
4245
httpReq.Header.Add("X-Inngest-Env", *a.env)
@@ -96,7 +99,9 @@ func (a *workerApiClient) sendBufferedMessage(ctx context.Context, hashedSigning
9699
}
97100

98101
httpReq.Header.Set("Content-Type", "application/protobuf")
99-
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(hashedSigningKey)))
102+
if hashedSigningKey != nil {
103+
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(hashedSigningKey)))
104+
}
100105

101106
if a.env != nil {
102107
httpReq.Header.Add("X-Inngest-Env", *a.env)

handler.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,16 @@ func (h handlerOpts) GetSigningKeyFallback() string {
136136

137137
// GetAPIOrigin returns the host to use for sending API requests
138138
func (h handlerOpts) GetAPIBaseURL() string {
139-
if h.isDev() {
140-
return DevServerURL()
141-
}
142-
143139
if h.APIBaseURL == nil {
144140
base := os.Getenv("INNGEST_API_BASE_URL")
145141
if base != "" {
146142
return base
147143
}
148144

145+
if h.isDev() {
146+
return DevServerURL()
147+
}
148+
149149
return defaultAPIOrigin
150150
}
151151

handler_test.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/inngest/inngest/pkg/syscode"
2222
"github.com/inngest/inngestgo/internal/sdkrequest"
2323
"github.com/inngest/inngestgo/step"
24+
"github.com/stretchr/testify/assert"
2425
"github.com/stretchr/testify/require"
2526
)
2627

@@ -1012,6 +1013,120 @@ func TestInBandSync(t *testing.T) {
10121013
})
10131014
}
10141015

1016+
func TestConnectSync(t *testing.T) {
1017+
t.Run("cloud", func(t *testing.T) {
1018+
// SDK sends an Authorization header when in cloud mode.
1019+
1020+
r := require.New(t)
1021+
1022+
// We need a cancellable context to stop the Connect worker.
1023+
connectCtx, cancelConnectCtx := context.WithCancel(context.Background())
1024+
defer cancelConnectCtx()
1025+
1026+
headers := http.Header{}
1027+
server := httptest.NewServer(http.HandlerFunc(
1028+
func(w http.ResponseWriter, r *http.Request) {
1029+
if r.URL.Path == "/v0/connect/start" {
1030+
for k, v := range r.Header {
1031+
headers.Add(k, v[0])
1032+
}
1033+
cancelConnectCtx()
1034+
w.WriteHeader(http.StatusOK)
1035+
return
1036+
}
1037+
1038+
w.WriteHeader(http.StatusNotFound)
1039+
_, _ = w.Write([]byte(`{}`))
1040+
},
1041+
))
1042+
defer server.Close()
1043+
1044+
c, err := NewClient(ClientOpts{
1045+
APIBaseURL: toPtr(server.URL),
1046+
AppID: "app",
1047+
SigningKey: toPtr(testKey),
1048+
})
1049+
r.NoError(err)
1050+
1051+
_, err = CreateFunction(
1052+
c,
1053+
FunctionOpts{ID: "fn"},
1054+
EventTrigger("event", nil),
1055+
func(ctx context.Context, input Input[any]) (any, error) {
1056+
return nil, nil
1057+
},
1058+
)
1059+
r.NoError(err)
1060+
1061+
_, _ = Connect(connectCtx, ConnectOpts{
1062+
Apps: []Client{c},
1063+
InstanceID: toPtr("instance"),
1064+
})
1065+
1066+
r.EventuallyWithT(func(t *assert.CollectT) {
1067+
a := assert.New(t)
1068+
a.NotEmpty(headers.Get("Authorization"))
1069+
}, 5*time.Second, 10*time.Millisecond)
1070+
})
1071+
1072+
t.Run("dev", func(t *testing.T) {
1073+
// SDK doesn't send an Authorization header when in dev mode.
1074+
1075+
r := require.New(t)
1076+
1077+
// We need a cancellable context to stop the Connect worker.
1078+
connectCtx, cancelConnectCtx := context.WithCancel(context.Background())
1079+
defer cancelConnectCtx()
1080+
1081+
headers := http.Header{}
1082+
called := false
1083+
server := httptest.NewServer(http.HandlerFunc(
1084+
func(w http.ResponseWriter, r *http.Request) {
1085+
if r.URL.Path == "/v0/connect/start" {
1086+
for k, v := range r.Header {
1087+
headers.Add(k, v[0])
1088+
}
1089+
called = true
1090+
cancelConnectCtx()
1091+
w.WriteHeader(http.StatusOK)
1092+
return
1093+
}
1094+
1095+
w.WriteHeader(http.StatusNotFound)
1096+
},
1097+
))
1098+
defer server.Close()
1099+
1100+
c, err := NewClient(ClientOpts{
1101+
APIBaseURL: toPtr(server.URL),
1102+
AppID: "app",
1103+
Dev: toPtr(true),
1104+
})
1105+
r.NoError(err)
1106+
1107+
_, err = CreateFunction(
1108+
c,
1109+
FunctionOpts{ID: "fn"},
1110+
EventTrigger("event", nil),
1111+
func(ctx context.Context, input Input[any]) (any, error) {
1112+
return nil, nil
1113+
},
1114+
)
1115+
r.NoError(err)
1116+
1117+
_, _ = Connect(connectCtx, ConnectOpts{
1118+
Apps: []Client{c},
1119+
InstanceID: toPtr("instance"),
1120+
})
1121+
1122+
r.EventuallyWithT(func(t *assert.CollectT) {
1123+
a := assert.New(t)
1124+
a.Empty(headers.Get("Authorization"))
1125+
a.True(called)
1126+
}, 5*time.Second, 10*time.Millisecond)
1127+
})
1128+
}
1129+
10151130
func createRequest(t *testing.T, evt any) *sdkrequest.Request {
10161131
t.Helper()
10171132

tests/invoke_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ import (
1515
)
1616

1717
func TestInvoke(t *testing.T) {
18-
if testing.Short() {
19-
t.Skip()
20-
}
18+
devEnv(t)
2119

2220
t.Run("success", func(t *testing.T) {
2321
ctx := context.Background()

tests/main_test.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ func TestMain(m *testing.M) {
2929
}
3030

3131
func setup() (func() error, error) {
32-
os.Setenv("INNGEST_DEV", "1")
33-
34-
if os.Getenv("DEV_SERVER_ENABLED") == "0" {
35-
// Don't start the Dev Server.
36-
return func() error { return nil }, nil
37-
}
38-
3932
stopDevServer, err := startDevServer()
4033
if err != nil {
4134
return nil, err
@@ -44,6 +37,11 @@ func setup() (func() error, error) {
4437
return stopDevServer, nil
4538
}
4639

40+
func devEnv(t *testing.T) {
41+
t.Helper()
42+
t.Setenv("INNGEST_DEV", "1")
43+
}
44+
4745
func startDevServer() (func() error, error) {
4846
fmt.Println("Starting Dev Server")
4947
cmd := exec.Command(

tests/parallel_test.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@ import (
1313
)
1414

1515
func TestParallel(t *testing.T) {
16-
if testing.Short() {
17-
t.Skip()
18-
}
19-
16+
devEnv(t)
2017
t.Run("successful with a mix of step kinds", func(t *testing.T) {
2118
ctx := context.Background()
2219
r := require.New(t)

tests/probe_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
const sKey = "signkey-prod-000000"
1717

1818
func TestTrustProbe(t *testing.T) {
19+
devEnv(t)
20+
1921
t.Run("dev mode", func(t *testing.T) {
2022
isDev := true
2123

0 commit comments

Comments
 (0)