Skip to content

Commit f58e1db

Browse files
authored
Merge pull request #11 from runreveal/rpc
Experimental RPC Framework
2 parents 35351c0 + 36f935f commit f58e1db

File tree

16 files changed

+1258
-1
lines changed

16 files changed

+1258
-1
lines changed

loader/loader.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ type Registry[T any] struct {
4242
}
4343

4444
var registry = struct {
45-
// map[type]registry[T] where T is variadic
45+
// map[typeString]registry[T] where T is variadic and typeString is:
46+
// reflect.TypeOf(T).String()
4647
m map[string]any
4748
sync.RWMutex
4849
}{

rpc/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
RPC is a framework for easily creating HTTP APIs in Go.
3+
4+
It handles serialization, validation, documentation and client generation.
5+
6+
It has escape hatches to access the standard Go HTTP request and response
7+
objects.
8+

rpc/context.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package rpc
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
type contextKey struct{ name string }
9+
10+
var (
11+
reqContextKey = contextKey{name: "requestKey"}
12+
respContextKey = contextKey{name: "responseKey"}
13+
)
14+
15+
func Request(ctx context.Context) *http.Request {
16+
v, ok := ctx.Value(reqContextKey).(*http.Request)
17+
if !ok {
18+
panic("request not set on context. ensure handler is wrapped")
19+
}
20+
return v
21+
}
22+
23+
func ResponseWriter(ctx context.Context) http.ResponseWriter {
24+
v, ok := ctx.Value(respContextKey).(*responseWrapper)
25+
if !ok {
26+
panic("response not set on context. ensure handler is wrapped")
27+
}
28+
return v
29+
}

rpc/errors.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package rpc
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"log/slog"
10+
"net/http"
11+
"runtime"
12+
"time"
13+
)
14+
15+
type CustomError interface {
16+
error
17+
Status() int
18+
Format(context.Context) any
19+
}
20+
21+
var (
22+
errorRegistry = []CustomError{}
23+
)
24+
25+
func RegisterErrorHandler(ce CustomError) {
26+
fmt.Printf("registering error handler: %p, %I, %v\n", &ce, ce, ce)
27+
errorRegistry = append(errorRegistry, ce)
28+
}
29+
30+
// HandleErr is equivlant to errResp but using more flexible error types
31+
func handleErr(ctx context.Context, w *responseWrapper, err error) {
32+
errorHelper("error: %v\n", err)
33+
34+
if w.status != 0 {
35+
slog.Error("response sent before error was handled")
36+
return
37+
}
38+
39+
e := json.NewEncoder(w)
40+
// Custom errors take precedence over default errors
41+
for _, ce := range errorRegistry {
42+
if errors.As(err, &ce) {
43+
w.WriteHeader(ce.Status())
44+
encErr := e.Encode(ce.Format(ctx))
45+
if encErr != nil {
46+
errorHelper("error encountered encoding error response: %v", encErr)
47+
}
48+
return
49+
}
50+
}
51+
52+
var (
53+
ue UserError
54+
ae AuthError
55+
le LimitError
56+
encErr error
57+
)
58+
switch {
59+
case errors.As(err, &ae):
60+
w.WriteHeader(http.StatusUnauthorized)
61+
encErr = e.Encode(RPCResponse{Error: ae.AuthError()})
62+
63+
case errors.As(err, &ue):
64+
w.WriteHeader(http.StatusBadRequest)
65+
encErr = e.Encode(RPCResponse{Error: ue.UserError()})
66+
67+
case errors.Is(err, context.Canceled):
68+
w.WriteHeader(http.StatusServiceUnavailable)
69+
encErr = e.Encode(RPCResponse{Error: "context canceled"})
70+
71+
case errors.Is(err, sql.ErrNoRows):
72+
w.WriteHeader(http.StatusBadRequest)
73+
encErr = e.Encode(RPCResponse{Error: "no record could be found"})
74+
75+
case errors.As(err, &le):
76+
w.WriteHeader(http.StatusUpgradeRequired)
77+
encErr = e.Encode(RPCResponse{Error: le.LimitError()})
78+
79+
default:
80+
w.WriteHeader(http.StatusInternalServerError)
81+
encErr = e.Encode(RPCResponse{Error: "unknown error"})
82+
}
83+
if encErr != nil {
84+
errorHelper("error encountered encoding error response: %v", encErr)
85+
}
86+
}
87+
88+
type UserError interface {
89+
UserError() string
90+
}
91+
92+
type userErr struct {
93+
err error
94+
}
95+
96+
func UserErr(err error) error {
97+
if err == nil {
98+
return nil
99+
}
100+
return userErr{
101+
err: err,
102+
}
103+
}
104+
105+
func (ue userErr) UserError() string {
106+
return ue.err.Error()
107+
}
108+
109+
func (ue userErr) Error() string {
110+
return ue.err.Error()
111+
}
112+
113+
//////////////////////
114+
115+
type AuthError interface {
116+
AuthError() string
117+
}
118+
119+
type authErr struct {
120+
err error
121+
}
122+
123+
func AuthErr(err error) error {
124+
if err == nil {
125+
return nil
126+
}
127+
return authErr{
128+
err: err,
129+
}
130+
}
131+
132+
func (ae authErr) AuthError() string {
133+
return ae.err.Error()
134+
}
135+
136+
func (ae authErr) Error() string {
137+
return ae.err.Error()
138+
}
139+
140+
//////////////////
141+
142+
type LimitError interface {
143+
LimitError() string
144+
}
145+
146+
type limitErr struct {
147+
err error
148+
}
149+
150+
func LimitErr(err error) error {
151+
if err == nil {
152+
return nil
153+
}
154+
return limitErr{
155+
err: err,
156+
}
157+
}
158+
159+
func (le limitErr) Error() string {
160+
return le.err.Error()
161+
}
162+
163+
func (le limitErr) LimitError() string {
164+
return le.err.Error()
165+
}
166+
167+
var (
168+
ErrLimitReached = AuthErr(errors.New("limit reached"))
169+
)
170+
171+
type ErrVersionMismatch struct {
172+
Err error
173+
ClientVersion string
174+
ServerVersion string
175+
}
176+
177+
func (e ErrVersionMismatch) Error() string {
178+
// Don't mask the wrapped error if it's rewrapped
179+
return e.Err.Error()
180+
}
181+
182+
func (e ErrVersionMismatch) Warning() string {
183+
return fmt.Sprintf("client/server version mismatch (c: %s s: %s).", e.ClientVersion, e.ServerVersion)
184+
}
185+
186+
func (e ErrVersionMismatch) Unwrap() error {
187+
return e.Err
188+
}
189+
190+
// errorHelper is a helper function to log errors to the default logger and
191+
// include the file and line number of the function calling HandleErr
192+
// we may want to inject the logger in the future
193+
func errorHelper(format string, args ...any) {
194+
l := slog.Default()
195+
if !l.Enabled(context.Background(), slog.LevelInfo) {
196+
return
197+
}
198+
var pcs [1]uintptr
199+
runtime.Callers(3, pcs[:]) // skip [Callers, errorHelper, HandleErr]
200+
r := slog.NewRecord(time.Now(), slog.LevelError, fmt.Sprintf(format, args...), pcs[0])
201+
_ = l.Handler().Handle(context.Background(), r)
202+
}

0 commit comments

Comments
 (0)