forked from beeker1121/nosurfctx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
middleware.go
128 lines (108 loc) · 3.53 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package nosurfctx
import (
"context"
"net/http"
"net/url"
)
// exemptMethods lists the HTTP methods that are exempt from CSRF
// verification when Protect is used: a token is still issued for these
// requests, but the token sent with them is not checked.
var exemptMethods = []string{
	"GET",
	"HEAD",
	"OPTIONS",
	"TRACE",
}
// defaultErrorHandler is the default error handler.
func defaultErrorHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Bad Request", http.StatusBadRequest)
}
// Export the public error handler so it can be modified.
var DefaultErrorHandler = defaultErrorHandler
// Protect is the standard middleware used for protecting routes from CSRF
// attacks, taking into account the exempt HTTP methods.
func Protect(h http.HandlerFunc) http.HandlerFunc {
return protect(h, true)
}
// NoProtect wraps h without performing any CSRF checks. It is meant for
// testing only and must never guard production routes.
//
// The request context is populated with the placeholder string "DUMMY" under
// csrfKey so downstream code that looks up the token finds a value.
// NOTE(review): the value is stored directly rather than through
// setTokenContext, so it may not have the same type as a real token — confirm
// token readers handle this.
func NoProtect(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), csrfKey, "DUMMY")
		h(w, r.WithContext(ctx))
	}
}
// ForceProtect wraps h with CSRF protection that is enforced for every HTTP
// method, ignoring exemptMethods.
//
// This can, for instance, be used to protect GET requests sent via AJAX.
func ForceProtect(h http.HandlerFunc) http.HandlerFunc {
	checkExempt := false
	return protect(h, checkExempt)
}
// protect returns a handler that applies CSRF protection before calling h.
//
// On every request it makes a CSRF token available in the request context.
// The token read from the CSRF cookie is reused when its length equals
// tokenLength. Otherwise a fresh token is generated, written to the cookie
// and stored in the context instead.
//
// When checkExempt is true, requests whose method is in exemptMethods are
// passed to h without verification. Every other request must pass the
// Referer origin check (for https request URLs) and send a token that
// verifyToken accepts against the cookie token. Any failure is reported via
// DefaultErrorHandler and h is not called.
func protect(h http.HandlerFunc, checkExempt bool) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Set Vary header so caches know the response depends on the
		// request's cookies.
		w.Header().Add("Vary", "Cookie")

		// Try to get the real token from the CSRF cookie.
		realToken := getTokenFromCookie(r)

		// A token whose length does not match tokenLength has been tampered
		// with, comes from an older algorithm, or was never set. In any of
		// those cases issue a new one (cookie + context). Otherwise just
		// expose the existing token through the context.
		//
		// setTokenContext's result goes into ctxReq and is only assigned to r
		// after it succeeded. This way DefaultErrorHandler always receives the
		// original request, even if setTokenContext returns a nil request
		// together with its error.
		var ctxReq *http.Request
		var err error
		if len(realToken) != tokenLength {
			token, genErr := generateToken()
			if genErr != nil {
				DefaultErrorHandler(w, r)
				return
			}
			setTokenCookie(w, token)
			ctxReq, err = setTokenContext(r, token)
		} else {
			ctxReq, err = setTokenContext(r, realToken)
		}
		if err != nil {
			DefaultErrorHandler(w, r)
			return
		}
		r = ctxReq

		// Skip to the success handler if the request method is
		// exempt from CSRF verification.
		if checkExempt && stringInSlice(r.Method, exemptMethods) {
			h(w, r)
			return
		}

		// If the request is secure, enforce that the Referer shares the
		// request's origin to prevent MITM of http->https requests.
		// The header name is spelled "Referer" in the HTTP specification.
		// Looking up "Referrer" never matches what clients actually send.
		//
		// NOTE(review): on the server side r.URL usually holds only the path
		// (Scheme and Host are empty), so this branch only runs for
		// absolute-form request URIs. Switching to r.TLS would also require
		// comparing against r.Host — confirm sameOrigin's semantics first.
		if r.URL.Scheme == "https" {
			referrer, parseErr := url.Parse(r.Header.Get("Referer"))
			// An unparsable or empty Referer is treated as not specified.
			if parseErr != nil || referrer.String() == "" {
				DefaultErrorHandler(w, r)
				return
			}
			// Reject when the Referer does not share origin with the request URL.
			if !sameOrigin(referrer, r.URL) {
				DefaultErrorHandler(w, r)
				return
			}
		}

		// Verify the sent token against the token read from the cookie.
		// realToken is intentionally the cookie value, not a freshly generated
		// one. A request that arrived without a valid cookie is therefore
		// checked against that invalid value, presumably rejected by
		// verifyToken (confirm).
		sentToken := getTokenFromRequest(r)
		tokenOk, err := verifyToken(realToken, sentToken)
		if err != nil || !tokenOk {
			DefaultErrorHandler(w, r)
			return
		}

		// Everything passed, call the next handler.
		h(w, r)
	}
}