Skip to content

Commit e812fa8

Browse files
authored
feat: graceful shutdown and context helpers (#9)
* feat: graceful shutdown * docs: add context helpers guide * test: skip SIGTERM on windows
1 parent ae196ad commit e812fa8

File tree

4 files changed

+332
-7
lines changed

4 files changed

+332
-7
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ See [Architecture Guide](docs/architecture.md) for details.
136136
- [Error Handling](docs/guides/error-handling.md) — Error patterns and Problem Details
137137
- [Validation](docs/guides/validation.md) — Input validation patterns
138138
- [Authentication](docs/guides/authentication.md) — Auth middleware and guards
139+
- [Context Helpers](docs/guides/context-helpers.md) — Typed context keys and helper functions
139140
- [Testing](docs/guides/testing.md) — Testing patterns
140141
- [Comparison](docs/guides/comparison.md) — vs Wire, Fx, and others
141142

docs/guides/context-helpers.md

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Context Helpers
2+
3+
Go doesn't use decorators. The idiomatic pattern is to attach request-scoped data to `context.Context` using typed keys, then provide small helper functions for setting and retrieving values.
4+
5+
This guide shows the pattern modkit recommends for context helpers, which you can use anywhere in your middleware or handlers.
6+
7+
## Typed Context Keys
8+
9+
Use an unexported key type to avoid collisions with other packages.
10+
11+
```go
12+
package auth
13+
14+
type userKey struct{}
15+
16+
var userKeyInstance = userKey{}
17+
```
18+
19+
Keeping the key type and value unexported prevents accidental use from other packages and makes collisions impossible.
20+
21+
## Helper Functions
22+
23+
Wrap `context.WithValue` and `ctx.Value` in helpers so your handlers stay type-safe and readable.
24+
25+
```go
26+
package auth
27+
28+
import "context"
29+
30+
type User struct {
31+
ID string
32+
Email string
33+
Role string
34+
}
35+
36+
func WithUser(ctx context.Context, user *User) context.Context {
37+
return context.WithValue(ctx, userKeyInstance, user)
38+
}
39+
40+
func UserFromContext(ctx context.Context) (*User, bool) {
41+
user, ok := ctx.Value(userKeyInstance).(*User)
42+
return user, ok
43+
}
44+
```
45+
46+
### Using in Middleware
47+
48+
```go
49+
func AuthMiddleware(next http.Handler) http.Handler {
50+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51+
user, err := authenticate(r)
52+
if err != nil {
53+
http.Error(w, "unauthorized", http.StatusUnauthorized)
54+
return
55+
}
56+
57+
ctx := auth.WithUser(r.Context(), user)
58+
next.ServeHTTP(w, r.WithContext(ctx))
59+
})
60+
}
61+
```
62+
63+
### Using in Handlers
64+
65+
```go
66+
func (c *UsersController) Me(w http.ResponseWriter, r *http.Request) {
67+
user, ok := auth.UserFromContext(r.Context())
68+
if !ok {
69+
http.Error(w, "unauthorized", http.StatusUnauthorized)
70+
return
71+
}
72+
73+
json.NewEncoder(w).Encode(user)
74+
}
75+
```
76+
77+
## Best Practices
78+
79+
- Keep context keys unexported in the package that defines them.
80+
- Prefer helper functions instead of calling `ctx.Value` directly.
81+
- Return `nil`/`false` when the value is missing rather than panicking.
82+
- Treat `context.Context` as request-scoped data only, not as a general dependency container.
83+
84+
Note: some existing modkit docs use exported context keys for brevity. In production code, unexported key types are the recommended practice to avoid collisions.
85+
86+
## Multiple Values
87+
88+
Use a separate key type and helpers for each value you need to store.
89+
90+
```go
91+
package requestid
92+
93+
import "context"
94+
95+
type requestIDKey struct{}
96+
97+
var requestIDKeyInstance = requestIDKey{}
98+
99+
func WithRequestID(ctx context.Context, id string) context.Context {
100+
return context.WithValue(ctx, requestIDKeyInstance, id)
101+
}
102+
103+
func RequestIDFromContext(ctx context.Context) (string, bool) {
104+
id, ok := ctx.Value(requestIDKeyInstance).(string)
105+
return id, ok
106+
}
107+
```
108+
109+
```go
110+
package tenant
111+
112+
import "context"
113+
114+
type tenantKey struct{}
115+
116+
var tenantKeyInstance = tenantKey{}
117+
118+
func WithTenant(ctx context.Context, id string) context.Context {
119+
return context.WithValue(ctx, tenantKeyInstance, id)
120+
}
121+
122+
func TenantFromContext(ctx context.Context) (string, bool) {
123+
id, ok := ctx.Value(tenantKeyInstance).(string)
124+
return id, ok
125+
}
126+
```
127+
128+
This keeps each value isolated and avoids type assertions scattered throughout your handlers.

modkit/http/server.go

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,64 @@
11
package http
22

3-
import "net/http"
3+
import (
4+
"context"
5+
"net/http"
6+
"os"
7+
"os/signal"
8+
"syscall"
9+
"time"
10+
)
411

5-
var listenAndServe = http.ListenAndServe
12+
// ShutdownTimeout controls how long the server will wait for in-flight requests
13+
// to finish after receiving a shutdown signal.
14+
var ShutdownTimeout = 30 * time.Second
15+
16+
var listenAndServe = func(server *http.Server) error {
17+
return server.ListenAndServe()
18+
}
19+
20+
var shutdownServer = func(ctx context.Context, server *http.Server) error {
21+
return server.Shutdown(ctx)
22+
}
623

724
// Serve starts an HTTP server on the given address using the provided handler.
25+
// It handles SIGINT and SIGTERM for graceful shutdown.
826
func Serve(addr string, handler http.Handler) error {
9-
return listenAndServe(addr, handler)
27+
server := &http.Server{
28+
Addr: addr,
29+
Handler: handler,
30+
}
31+
32+
sigCh := make(chan os.Signal, 1)
33+
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
34+
defer func() {
35+
signal.Stop(sigCh)
36+
close(sigCh)
37+
}()
38+
39+
errCh := make(chan error, 1)
40+
go func() {
41+
errCh <- listenAndServe(server)
42+
}()
43+
44+
select {
45+
case err := <-errCh:
46+
if err == http.ErrServerClosed {
47+
return nil
48+
}
49+
return err
50+
case <-sigCh:
51+
ctx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
52+
defer cancel()
53+
54+
shutdownErr := shutdownServer(ctx, server)
55+
err := <-errCh
56+
if err == http.ErrServerClosed {
57+
err = nil
58+
}
59+
if shutdownErr != nil {
60+
return shutdownErr
61+
}
62+
return err
63+
}
1064
}

modkit/http/server_test.go

Lines changed: 146 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
11
package http
22

33
import (
4+
"context"
45
"errors"
6+
"io"
7+
"net"
58
"net/http"
9+
"os"
10+
"runtime"
11+
"syscall"
612
"testing"
13+
"time"
714
)
815

9-
func TestServe_UsesHTTPServer(t *testing.T) {
16+
func TestServe_ReturnsErrorWhenServerFailsToStart(t *testing.T) {
1017
originalListen := listenAndServe
18+
originalShutdown := shutdownServer
1119
defer func() {
1220
listenAndServe = originalListen
21+
shutdownServer = originalShutdown
1322
}()
1423

1524
var gotAddr string
1625
var gotHandler http.Handler
17-
listenAndServe = func(addr string, handler http.Handler) error {
18-
gotAddr = addr
19-
gotHandler = handler
26+
listenAndServe = func(server *http.Server) error {
27+
gotAddr = server.Addr
28+
gotHandler = server.Handler
2029
return errors.New("boom")
2130
}
31+
shutdownServer = func(ctx context.Context, server *http.Server) error {
32+
return nil
33+
}
2234

2335
router := NewRouter()
2436
err := Serve("127.0.0.1:12345", router)
@@ -33,3 +45,133 @@ func TestServe_UsesHTTPServer(t *testing.T) {
3345
t.Fatalf("expected error from listenAndServe, got %v", err)
3446
}
3547
}
48+
49+
func TestServe_HandlesSignals_ReturnsNil(t *testing.T) {
50+
for _, tt := range []struct {
51+
name string
52+
sig os.Signal
53+
}{
54+
{name: "SIGINT", sig: os.Interrupt},
55+
{name: "SIGTERM", sig: syscall.SIGTERM},
56+
} {
57+
t.Run(tt.name, func(t *testing.T) {
58+
originalListen := listenAndServe
59+
originalShutdown := shutdownServer
60+
defer func() {
61+
listenAndServe = originalListen
62+
shutdownServer = originalShutdown
63+
}()
64+
65+
serveStarted := make(chan struct{})
66+
shutdownRequested := make(chan struct{})
67+
68+
listenAndServe = func(server *http.Server) error {
69+
close(serveStarted)
70+
<-shutdownRequested
71+
return http.ErrServerClosed
72+
}
73+
shutdownServer = func(ctx context.Context, server *http.Server) error {
74+
close(shutdownRequested)
75+
return nil
76+
}
77+
78+
errCh := make(chan error, 1)
79+
go func() {
80+
errCh <- Serve("127.0.0.1:12345", NewRouter())
81+
}()
82+
83+
<-serveStarted
84+
if tt.sig == syscall.SIGTERM && runtime.GOOS == "windows" {
85+
t.Skip("SIGTERM not supported on Windows")
86+
}
87+
proc, err := os.FindProcess(os.Getpid())
88+
if err != nil {
89+
t.Fatalf("failed to find process: %v", err)
90+
}
91+
if err := proc.Signal(tt.sig); err != nil {
92+
t.Fatalf("failed to send signal: %v", err)
93+
}
94+
95+
if err := <-errCh; err != nil {
96+
t.Fatalf("expected nil on clean shutdown, got %v", err)
97+
}
98+
})
99+
}
100+
}
101+
102+
func TestServe_ShutdownWaitsForInFlightRequest(t *testing.T) {
103+
originalListen := listenAndServe
104+
originalShutdown := shutdownServer
105+
originalTimeout := ShutdownTimeout
106+
defer func() {
107+
listenAndServe = originalListen
108+
shutdownServer = originalShutdown
109+
ShutdownTimeout = originalTimeout
110+
}()
111+
112+
ShutdownTimeout = 2 * time.Second
113+
114+
ln, err := net.Listen("tcp", "127.0.0.1:0")
115+
if err != nil {
116+
t.Fatalf("failed to listen: %v", err)
117+
}
118+
addr := ln.Addr().String()
119+
120+
requestStarted := make(chan struct{})
121+
releaseRequest := make(chan struct{})
122+
123+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124+
close(requestStarted)
125+
<-releaseRequest
126+
w.WriteHeader(http.StatusOK)
127+
})
128+
129+
listenAndServe = func(server *http.Server) error {
130+
return server.Serve(ln)
131+
}
132+
shutdownServer = func(ctx context.Context, server *http.Server) error {
133+
return server.Shutdown(ctx)
134+
}
135+
136+
serveErrCh := make(chan error, 1)
137+
go func() {
138+
serveErrCh <- Serve(addr, handler)
139+
}()
140+
141+
clientErrCh := make(chan error, 1)
142+
go func() {
143+
resp, err := http.Get("http://" + addr)
144+
if err != nil {
145+
clientErrCh <- err
146+
return
147+
}
148+
_, _ = io.Copy(io.Discard, resp.Body)
149+
_ = resp.Body.Close()
150+
clientErrCh <- nil
151+
}()
152+
153+
<-requestStarted
154+
155+
proc, err := os.FindProcess(os.Getpid())
156+
if err != nil {
157+
t.Fatalf("failed to find process: %v", err)
158+
}
159+
if err := proc.Signal(os.Interrupt); err != nil {
160+
t.Fatalf("failed to send signal: %v", err)
161+
}
162+
163+
select {
164+
case err := <-serveErrCh:
165+
t.Fatalf("expected Serve to wait for in-flight request, got %v", err)
166+
case <-time.After(200 * time.Millisecond):
167+
}
168+
169+
close(releaseRequest)
170+
171+
if err := <-serveErrCh; err != nil {
172+
t.Fatalf("expected nil on clean shutdown, got %v", err)
173+
}
174+
if err := <-clientErrCh; err != nil {
175+
t.Fatalf("request failed: %v", err)
176+
}
177+
}

0 commit comments

Comments
 (0)