-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware_patch.go
More file actions
148 lines (120 loc) · 4.79 KB
/
middleware_patch.go
File metadata and controls
148 lines (120 loc) · 4.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package turtleware
import (
"github.com/rs/zerolog"
"context"
"encoding/json"
"errors"
"net/http"
"time"
)
var (
	// ErrUnmodifiedSinceHeaderMissing signals that the request carried no
	// (or an empty) If-Unmodified-Since header.
	ErrUnmodifiedSinceHeaderMissing = errors.New("If-Unmodified-Since header missing")
	// ErrUnmodifiedSinceHeaderInvalid signals that an If-Unmodified-Since header
	// was present but its value could not be parsed as a date/time.
	ErrUnmodifiedSinceHeaderInvalid = errors.New("received If-Unmodified-Since header in invalid format")
	// ErrNoDateTimeLayoutMatched signals that a date/time value matched none of
	// the layouts it was tried against.
	ErrNoDateTimeLayoutMatched = errors.New("no date time layout matched")

	// ErrNoChanges signals that a patch request did not change anything.
	ErrNoChanges = errors.New("patch request did not contain any changes")
)
// PatchFunc is a function called for delegating the actual updating of an
// existing resource.
//
// It receives the UUID of the entity to update, the UUID of the user issuing
// the request, the decoded patch, and the instant taken from the request's
// If-Unmodified-Since header. When used with ResourcePatchMiddleware, the patch
// has already been checked with HasChanges and Validate, and a non-nil returned
// error is logged and forwarded to the middleware's error handler.
type PatchFunc[T PatchDTO] func(ctx context.Context, entityUUID, userUUID string, patch T, ifUnmodifiedSince time.Time) error

// PatchDTO defines the contract for validating a DTO used for patching an
// existing resource.
type PatchDTO interface {
	// HasChanges reports whether the patch modifies anything at all.
	// ResourcePatchMiddleware rejects a patch without changes with ErrNoChanges.
	HasChanges() bool
	// Validate returns every problem found in the patch. A nil or empty
	// slice means the patch is valid.
	Validate() []error
}
// IsHandledByDefaultPatchErrorHandler reports whether err gets a specific
// treatment from DefaultPatchErrorHandler. That is the case when err matches
// one of the patch-specific sentinels (ErrUnmodifiedSinceHeaderInvalid,
// ErrNoChanges, ErrUnmodifiedSinceHeaderMissing), or when
// IsHandledByDefaultErrorHandler reports true for it.
func IsHandledByDefaultPatchErrorHandler(err error) bool {
	switch {
	case errors.Is(err, ErrUnmodifiedSinceHeaderInvalid),
		errors.Is(err, ErrNoChanges),
		errors.Is(err, ErrUnmodifiedSinceHeaderMissing):
		return true
	default:
		return IsHandledByDefaultErrorHandler(err)
	}
}
// DefaultPatchErrorHandler writes an error response for err, mapping the
// patch-specific errors known to turtleware onto status codes:
//
//   - ErrUnmodifiedSinceHeaderInvalid or ErrNoChanges: 400 Bad Request
//   - ErrUnmodifiedSinceHeaderMissing: 428 Precondition Required
//
// Every other error is passed on to DefaultErrorHandler.
func DefaultPatchErrorHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
	switch {
	case errors.Is(err, ErrUnmodifiedSinceHeaderInvalid), errors.Is(err, ErrNoChanges):
		WriteError(ctx, w, r, http.StatusBadRequest, err)
	case errors.Is(err, ErrUnmodifiedSinceHeaderMissing):
		WriteError(ctx, w, r, http.StatusPreconditionRequired, err)
	default:
		DefaultErrorHandler(ctx, w, r, err)
	}
}
// ResourcePatchMiddleware is a middleware for patching or updating an existing resource.
//
// For each request it, in order:
//  1. reads the user UUID and the entity UUID from the request context,
//  2. decodes the JSON request body into a T,
//  3. rejects the patch with ErrNoChanges if T.HasChanges reports false,
//  4. rejects the patch with a *ValidationWrapperError if T.Validate returns errors,
//  5. parses the If-Unmodified-Since header via GetIfUnmodifiedSince,
//  6. calls patchFunc with the collected values.
//
// The first failing step stops processing, and its error is passed to errorHandler.
// Body decoding failures are reported as ErrMarshalling joined with the decoder's
// error, so errors.Is(err, ErrMarshalling) holds. Errors returned by patchFunc are
// additionally logged. If every step succeeds and next is non-nil, next is invoked.
func ResourcePatchMiddleware[T PatchDTO](patchFunc PatchFunc[T], errorHandler ErrorHandlerFunc) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			patchContext, cancel := context.WithCancel(r.Context())
			defer cancel()

			logger := zerolog.Ctx(patchContext)

			userUUID, err := UserUUIDFromRequestContext(patchContext)
			if err != nil {
				errorHandler(patchContext, w, r, err)
				return
			}

			entityUUID, err := EntityUUIDFromRequestContext(patchContext)
			if err != nil {
				errorHandler(patchContext, w, r, err)
				return
			}

			// Server requests always carry a non-nil Body, but hand-built requests
			// (e.g. in tests) may not, and decoding from a nil reader would panic.
			if r.Body == nil {
				errorHandler(patchContext, w, r, ErrMarshalling)
				return
			}

			var patch T
			if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
				// Keep the decoder's error as the cause so malformed bodies can be
				// diagnosed, the same way GetIfUnmodifiedSince joins its sentinel
				// with the parse cause. errors.Is(err, ErrMarshalling) still matches.
				errorHandler(patchContext, w, r, errors.Join(ErrMarshalling, err))
				return
			}

			if !patch.HasChanges() {
				errorHandler(patchContext, w, r, ErrNoChanges)
				return
			}

			if validationErrors := patch.Validate(); len(validationErrors) > 0 {
				errorHandler(patchContext, w, r, &ValidationWrapperError{validationErrors})
				return
			}

			ifUnmodifiedSince, err := GetIfUnmodifiedSince(r)
			if err != nil {
				errorHandler(patchContext, w, r, err)
				return
			}

			if err := patchFunc(patchContext, entityUUID, userUUID, patch, ifUnmodifiedSince); err != nil {
				logger.Error().Err(err).Msg("Patch failed")
				errorHandler(patchContext, w, r, err)
				return
			}

			if next != nil {
				next.ServeHTTP(w, r)
			}
		})
	}
}
// GetIfUnmodifiedSince tries to parse a time.Time from the If-Unmodified-Since header of
// a given request. It tries the following formats (in that order):
//
//   - time.RFC1123 (IMF-fixdate, e.g. "Sun, 06 Nov 1994 08:49:37 GMT")
//   - time.RFC850 (obsolete RFC 850 form, e.g. "Sunday, 06-Nov-94 08:49:37 GMT")
//   - time.ANSIC (obsolete asctime form, e.g. "Sun Nov  6 08:49:37 1994")
//   - time.RFC3339Nano
//   - time.RFC3339
//
// The two obsolete forms are accepted because HTTP requires recipients to
// understand all three HTTP-date formats (RFC 9110, section 5.6.7).
//
// It returns ErrUnmodifiedSinceHeaderMissing if the header is absent or empty.
// If no format matches, it returns an error that matches both
// ErrUnmodifiedSinceHeaderInvalid and ErrNoDateTimeLayoutMatched.
func GetIfUnmodifiedSince(r *http.Request) (time.Time, error) {
	ifUnmodifiedSinceHeader := r.Header.Get("If-Unmodified-Since")
	if ifUnmodifiedSinceHeader == "" {
		return time.Time{}, ErrUnmodifiedSinceHeaderMissing
	}

	ifUnmodifiedSince, err := parseTimeByFormats(ifUnmodifiedSinceHeader,
		time.RFC1123, time.RFC850, time.ANSIC, time.RFC3339Nano, time.RFC3339)
	if err != nil {
		return time.Time{}, errors.Join(ErrUnmodifiedSinceHeaderInvalid, err)
	}

	return ifUnmodifiedSince, nil
}

// parseTimeByFormats parses value with each layout in turn and returns the
// first successful result. If no layout matches (including when none are
// given), it returns ErrNoDateTimeLayoutMatched.
func parseTimeByFormats(value string, layouts ...string) (time.Time, error) {
	for _, layout := range layouts {
		if parsedValue, err := time.Parse(layout, value); err == nil {
			return parsedValue, nil
		}
	}

	return time.Time{}, ErrNoDateTimeLayoutMatched
}